tf.estimator input_fn的签名可能看起来像这样:

def input_fn(files:list, params:dict):
    dataset = tf.data.TFRecordDataset(files)
                .map(lambda record: parse_record_fn(record))

    if params['mode'] == 'train':
        # train specific things
    # ...


这样的定义允许然后构造它们的所有input_fn,如下所示:

train_fn = lambda: input_fn(files['training_set'], {**params, **{"mode": "train"}})
valid_fn = lambda: input_fn(files['validation_set'], {**params, **{"mode": "eval"}})
test_fn  = lambda: input_fn(files['test_set'],  {**params, **{"mode": "test"}})


train_spec = tf.estimator.TrainSpec(input_fn=train_fn, ...)
eval_spec  = tf.estimator.EvalSpec(input_fn=valid_fn,  ...)


我的问题是,如何更改input_fn签名以允许基于时代的变化。我知道这可能会带来瓶颈,但是如果我可以做以下事情会很好:


def input_fn(...):
    # see above

    epoch = params["epoch"]
    if epoch % 100 == 0:
        # modify or make a new dataset

    # ...
    return dataset.make_one_shot_iterator().get_next()


关键是确保input_fn仍与以下设备兼容:

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

最佳答案

我不知道有提供epoch数字作为参数的任何选项。

就是说,根据定义,纪元是输入函数的功能,因此我们应该只能够处理输入函数中的所有内容,而不是完全可以访问训练参数。因此,我认为您只要稍微摆弄一下就可以实现所需的功能。

例如,如果我有2个数据集:ds1ds2说并且想在数字不能被100整除时使用ds1,那么我可以通过执行以下操作来创建新的数据集:

dataset = ds1.repeat(99).concatenate(ds2)


由于默认情况下数据集是延迟加载的,因此我不必担心内存的问题(我不会将100倍的数据加载到内存中)。

显然,这确实对数据集的大小有影响,因此您需要考虑在eval操作/回调等之间进行操作的策略,但这应该很容易进行调整。

关于python - Tensorflow 1.10+:将纪元传递给估计器input_fn?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/56594469/

10-12 21:15