本文介绍了tf.scatter_update() 如何在 while_loop() 中工作的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试使用 tf.scatter_update() 更新 tf.while_loop() 内的 tf.Variable.但是,结果是初始值而不是更新值.这是我正在尝试做的示例代码:

I am trying to update a tf.Variable inside a tf.while_loop(), using tf.scatter_update(). However, the result is the initial value instead of the updated value. Here is the sample code of what I am trying to do:

from __future__ import print_function

import tensorflow as tf

def cond(sequence_len, step):
    return tf.less(step,sequence_len)

def body(sequence_len, step):

    begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
    begin = tf.scatter_update(begin,1,step,use_locking=None)

    tf.get_variable_scope().reuse_variables()
   return (sequence_len, step+1)

with tf.Graph().as_default():

    sess = tf.Session()
    step = tf.constant(0)
    sequence_len  = tf.constant(10)
    _,step, = tf.while_loop(cond,
                    body,
                    [sequence_len, step],
                    parallel_iterations=10,
                    back_prop=True,
                    swap_memory=False,
                    name=None)

    begin = tf.get_variable("begin",[3],dtype=tf.int32)

    init = tf.initialize_all_variables()
    sess.run(init)

    print(sess.run([begin,step]))

结果为:[array([0, 0, 0], dtype=int32), 10].但是,我认为结果应该是 [0, 0, 10].我在这里做错了吗?

The result is: [array([0, 0, 0], dtype=int32), 10]. However, I think the result should be [0, 0, 10]. Am I doing something wrong here?

推荐答案

这里的问题是循环体中没有任何内容依赖于您的 tf.scatter_update() 操作,因此它永远不会被执行.使其工作的最简单方法是添加对返回值更新的控制依赖:

The problem here is that nothing in the loop body depends on your tf.scatter_update() op, so it is never executed. The easiest way to make it work is to add a control dependency on the update to the return values:

def body(sequence_len, step):
    begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
    begin = tf.scatter_update(begin, 1, step, use_locking=None)
    tf.get_variable_scope().reuse_variables()

    with tf.control_dependencies([begin]):
        return (sequence_len, step+1)

请注意,此问题并非 TensorFlow 中的循环所独有.如果您刚刚定义了一个名为 begintf.scatter_update() 操作,但在其上调用 sess.run() ,或者依赖于它,那么更新将不会发生.当您使用 tf.while_loop() 没有办法直接运行循环体中定义的操作,所以最容易产生副作用的方法是添加控件依赖.

Note that this problem isn't unique to loops in TensorFlow. If you had just defined an tf.scatter_update() op called begin but call sess.run() on it, or something that depends on it, then the update won't happen. When you're using the tf.while_loop() there's no way to run the operations defined in the loop body directly, so the easiest way to get a side effect is to add a control dependency.

注意最后的结果是[0, 9, 0]:每次迭代都把当前的step赋值给begin[1],最后一次迭代的值当前步骤的值为9(step == 10时条件为false).

Note that the final result is [0, 9, 0]: each iteration assigns the current step to begin[1], and in the last iteration the value of the current step is 9 (the condition is false when step == 10).

这篇关于tf.scatter_update() 如何在 while_loop() 中工作的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

09-05 10:18