所有官方的google教程都使用一次性迭代器实现所有的估计器api,我找不到任何关于如何使用tf.data的可初始化迭代器和可重新初始化interator而不是一次性迭代器的文档。
有人能告诉我如何使用tf.data的可初始化迭代器和可重新初始化的迭代器在训练数据和测试数据之间切换吗?我们需要运行一个会话来使用feed dict,并在可初始化迭代器中切换数据集,它是一个低级的api,并且很难理解如何使用它作为估计器api体系结构的一部分。
附:我确实发现谷歌提到
注意:目前,一次迭代是唯一一种易于使用估计器的类型。
但是社区里有工作吗?或者我们应该出于某种原因坚持使用一次迭代
最佳答案
若要使用可初始化或可重新初始化的迭代器,必须创建继承自tf.train.sessionrunhook的类。然后这个类可以访问tf.estimator函数使用的会话。
下面是一个快速的示例,您可以根据自己的需要进行调整:
class IteratorInitializerHook(tf.train.SessionRunHook):
def __init__(self):
super(IteratorInitializerHook, self).__init__()
self.iterator_initializer_func = None # Will be set in the input_fn
def after_create_session(self, session, coord):
self.iterator_initializer_func(session)
def get_inputs(X, y):
iterator_initializer_hook = IteratorInitializerHook()
def input_fn():
X_pl = tf.placeholder(X.dtype, X.shape)
y_pl = tf.placeholder(y.dtype, y.shape)
dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
dataset = ...
...
iterator = dataset.make_initializable_iterator()
next_example, next_label = iterator.get_next()
iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(iterator.initializer,
feed_dict={X_pl: X, y_pl: y})
return next_example, next_label
return input_fn, iterator_initializer_hook
...
train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)
...
estimator.train(input_fn=train_input_fn,
hooks=[train_iterator_initializer_hook])
estimator.evaluate(input_fn=test_input_fn,
hooks=[test_iterator_initializer_hook])
这是我在blogpost中通过Sebastian Pölsterl找到的代码的修改版本。查看“通过数据集API向估计器提供数据”部分。