目前,我试图在Tensorflow while循环中实现所有的训练,但是Tensorflow dataset API的迭代器有问题。
通常,当调用sess.run()时,Iterator.get_next()前进到下一个元素。
但是,我需要在一次运行中前进到下一个元素。我该怎么做?
下面的小例子说明了我的问题:
import tensorflow as tf
import numpy as np
def for_loop(condition, modifier, body_op, idx=0):
idx = tf.convert_to_tensor(idx)
def body(i):
with tf.control_dependencies([body_op(i)]):
return [modifier(i)]
# do the loop:
loop = tf.while_loop(condition, body, [idx])
return loop
x = np.arange(10)
data = tf.data.Dataset.from_tensor_slices(x)
data = data.repeat()
iterator = data.make_initializable_iterator()
smpl = iterator.get_next()
loop = for_loop(
condition=lambda i: tf.less(i, 5),
modifier=lambda i: tf.add(i, 1),
body_op=lambda i: tf.Print(smpl, [smpl], message="This is sample: ")
)
sess = tf.InteractiveSession()
sess.run(iterator.initializer)
sess.run(loop)
输出:
This is sample: [0]
This is sample: [0]
This is sample: [0]
This is sample: [0]
This is sample: [0]
我总是得到完全相同的元素。
最佳答案
每次想“在一次运行中迭代”时,都需要调用iterator.get_next()
。
例如,在玩具示例中,只需将body_op
替换为:
body_op=lambda i: tf.Print(i, [iterator.get_next()], message="This is sample: ")
# This is sample: [0]
# This is sample: [1]
# This is sample: [2]
# This is sample: [3]
# This is sample: [4]
关于python - tf.data.Iterator.get_next():如何在tf.while_loop中前进?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/50237486/