本文介绍了Tensorflow Dataset API 在完成一个 epoch 后恢复 Iterator的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有 190 个特征和标签,我的批量大小是 20,但经过 9 次迭代 tf.reshape 返回异常 要重塑的输入是一个具有 21 个值的张量,但请求的形状有60,我知道这是由于 Iterator.get_next() 造成的.我如何恢复我的迭代器,以便它再次从头开始提供批次服务?

I have 190 features and labels,My batch size is 20 but after 9 iterations tf.reshape is returning exception Input to reshape is a tensor with 21 values,but the requested shape has 60 and i know it is due to Iterator.get_next().How do i restore my Iterator so that it will again start serving batches from the beginning?

推荐答案

如果你想重启一个 tf.data.Iterator 从它的 Dataset 开始,考虑使用 initializable 迭代器,它有您可以运行以重新初始化迭代器的操作:

If you want to restart a tf.data.Iterator from the beginning of its Dataset, consider using an initializable iterator, which has an operation you can run to re-initialize the iterator:

dataset = ...  # A `tf.data.Dataset` instance.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

train_op = ...  # Something that depends on `next_element`.

for _ in range(NUM_EPOCHS):
  # Initialize the iterator at the beginning of `dataset`.
  sess.run(iterator.initializer)

  # Loop over the examples in `iterator`, running `train_op`.
  try:
    while True:
      sess.run(train_op)

  except tf.errors.OutOfRangeError:  # Thrown at the end of the epoch.
    pass

  # Perform any per-epoch computations here.

有关不同类型的 Iterator 的更多详细信息,请参阅 tf.data 程序员指南.

For more details on the different kinds of Iterator, see the tf.data programmer's guide.

这篇关于Tensorflow Dataset API 在完成一个 epoch 后恢复 Iterator的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

10-23 22:21