本文介绍了带有 TensorArray 的 TensorFlow while 循环的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
import tensorflow as tf
B = 3
D = 4
T = 5
tf.reset_default_graph()
xs = tf.placeholder(shape=[T, B, D], dtype=tf.float32)
with tf.variable_scope("RNN"):
GRUcell = tf.contrib.rnn.GRUCell(num_units = D)
cell = tf.contrib.rnn.MultiRNNCell([GRUcell])
output_ta = tf.TensorArray(size=T, dtype=tf.float32)
input_ta = tf.TensorArray(size=T, dtype=tf.float32)
input_ta.unstack(xs)
def body(time, output_ta_t, state):
xt = input_ta.read(time)
new_output, new_state = cell(xt, state)
output_ta_t.write(time, new_output)
return (time+1, output_ta_t, new_state)
def condition(time, output, state):
return time < T
time = 0
state = cell.zero_state(B, tf.float32)
time_final, output_ta_final, state_final = tf.while_loop(
cond=condition,
body=body,
loop_vars=(time, output_ta, state))
output_final = output_ta_final.stack()
然后我运行它
x = np.random.normal(size=(T, B, D))
with tf.Session() as sess:
tf.global_variables_initializer().run()
output_final_, state_final_ = sess.run(fetches = [output_final, state_final], feed_dict = {xs:x})
我想了解如何在 TensorFlow while 循环中正确使用 TensorArray.在上面的示例中,我收到以下错误:
I would like to understand how to use TensorArray properly in relation with TensorFlow while loop. In the above sample I get the following error:
InvalidArgumentError: TensorArray RNN/TensorArray_1_21: Could not read from TensorArray index 0 because it has not yet been written to.
我不明白这个无法从 TensorArray 索引 0 读取".我想我通过 unstack 写入 TensorArray input_ta 并在 while 主体中写入 output_ta .我做错了什么?感谢您的帮助.
I do not understand this "could not read from TensorArray index 0". I think I write to the TensorArray input_ta by unstack and to output_ta in the while body. What do I do wrong? Thanks for your help.
推荐答案
解决方案是改变
input_ta.unstack(xs)
到
input_ta = input_ta.unstack(xs)
同样改变
output_ta_t.write(time, new_output)
到
output_ta_t = output_ta_t.write(time, new_output)
通过这两个更改,代码按预期运行.
With these two changes the code runs as expected.
这篇关于带有 TensorArray 的 TensorFlow while 循环的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!