问题描述
我正在尝试使用 tf.scatter_update()
更新 tf.while_loop()
内的 tf.Variable
.但是,结果是初始值而不是更新值.这是我正在尝试做的示例代码:
I am trying to update a tf.Variable
inside a tf.while_loop()
, using tf.scatter_update()
. However, the result is the initial value instead of the updated value. Here is the sample code of what I am trying to do:
from __future__ import print_function
import tensorflow as tf
def cond(sequence_len, step):
return tf.less(step,sequence_len)
def body(sequence_len, step):
begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
begin = tf.scatter_update(begin,1,step,use_locking=None)
tf.get_variable_scope().reuse_variables()
return (sequence_len, step+1)
with tf.Graph().as_default():
sess = tf.Session()
step = tf.constant(0)
sequence_len = tf.constant(10)
_,step, = tf.while_loop(cond,
body,
[sequence_len, step],
parallel_iterations=10,
back_prop=True,
swap_memory=False,
name=None)
begin = tf.get_variable("begin",[3],dtype=tf.int32)
init = tf.initialize_all_variables()
sess.run(init)
print(sess.run([begin,step]))
结果为:[array([0, 0, 0], dtype=int32), 10]
.但是,我认为结果应该是 [0, 0, 10]
.我在这里做错了吗?
The result is: [array([0, 0, 0], dtype=int32), 10]
. However, I think the result should be [0, 0, 10]
. Am I doing something wrong here?
推荐答案
这里的问题是循环体中没有任何内容依赖于您的 tf.scatter_update()
操作,因此它永远不会被执行.使其工作的最简单方法是添加对返回值更新的控制依赖:
The problem here is that nothing in the loop body depends on your tf.scatter_update()
op, so it is never executed. The easiest way to make it work is to add a control dependency on the update to the return values:
def body(sequence_len, step):
begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
begin = tf.scatter_update(begin, 1, step, use_locking=None)
tf.get_variable_scope().reuse_variables()
with tf.control_dependencies([begin]):
return (sequence_len, step+1)
请注意,此问题并非 TensorFlow 中的循环所独有.如果您刚刚定义了一个名为 begin
的 tf.scatter_update()
操作,但在其上调用 sess.run()
,或者依赖于它,那么更新将不会发生.当您使用 tf.while_loop()
没有办法直接运行循环体中定义的操作,所以最容易产生副作用的方法是添加控件依赖.
Note that this problem isn't unique to loops in TensorFlow. If you had just defined an tf.scatter_update()
op called begin
but call sess.run()
on it, or something that depends on it, then the update won't happen. When you're using the tf.while_loop()
there's no way to run the operations defined in the loop body directly, so the easiest way to get a side effect is to add a control dependency.
注意最后的结果是[0, 9, 0]
:每次迭代都把当前的step赋值给begin[1]
,最后一次迭代的值当前步骤的值为9
(step == 10
时条件为false).
Note that the final result is [0, 9, 0]
: each iteration assigns the current step to begin[1]
, and in the last iteration the value of the current step is 9
(the condition is false when step == 10
).
这篇关于tf.scatter_update() 如何在 while_loop() 中工作的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!