tf.estimator
input_fn
的签名可能看起来像这样:
def input_fn(files:list, params:dict):
dataset = tf.data.TFRecordDataset(files)
.map(lambda record: parse_record_fn(record))
if params['mode'] == 'train':
# train specific things
# ...
这样的定义允许然后构造它们的所有
input_fn
,如下所示:train_fn = lambda: input_fn(files['training_set'], {**params, **{"mode": "train"}})
valid_fn = lambda: input_fn(files['validation_set'], {**params, **{"mode": "eval"}})
test_fn = lambda: input_fn(files['test_set'], {**params, **{"mode": "test"}})
train_spec = tf.estimator.TrainSpec(input_fn=train_fn, ...)
eval_spec = tf.estimator.EvalSpec(input_fn=valid_fn, ...)
我的问题是,如何更改
input_fn
签名以允许基于时代的变化。我知道这可能会带来瓶颈,但是如果我可以做以下事情会很好:
def input_fn(...):
# see above
epoch = params["epoch"]
if epoch % 100 == 0:
# modify or make a new dataset
# ...
return dataset.make_one_shot_iterator().get_next()
关键是确保
input_fn
仍与以下设备兼容:tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
最佳答案
我不知道有提供epoch
数字作为参数的任何选项。
就是说,根据定义,纪元是输入函数的功能,因此我们应该只能够处理输入函数中的所有内容,而不是完全可以访问训练参数。因此,我认为您只要稍微摆弄一下就可以实现所需的功能。
例如,如果我有2个数据集:ds1
和ds2
说并且想在数字不能被100整除时使用ds1
,那么我可以通过执行以下操作来创建新的数据集:
dataset = ds1.repeat(99).concatenate(ds2)
由于默认情况下数据集是延迟加载的,因此我不必担心内存的问题(我不会将100倍的数据加载到内存中)。
显然,这确实对数据集的大小有影响,因此您需要考虑在eval操作/回调等之间进行操作的策略,但这应该很容易进行调整。
关于python - Tensorflow 1.10+:将纪元传递给估计器input_fn?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/56594469/