我正在建立一个RNN模型,其中init_state可能来自以下两种情况之一。 1)从前一个时间步输出状态通过feed_dict馈入的静态init_state。 2)变量的某些功能,我称之为得分。
init_state = cell.zero_state(batch,tf.float32)
with tf.name_scope('hidden1'):
weights_h1 = tf.Variable(
tf.truncated_normal([T, cells_dim],
stddev=1.0 / np.sqrt(T)),
name='weights')
biases_h1 = tf.Variable(tf.zeros([cells_dim]),
name='biases')
hidden1 = tf.nn.relu(tf.matmul(score, weights_h1) + biases_h1)
init_state2 = tf.cond(is_start, lambda: hidden1, lambda: init_state)
然后将init_state2用作static_rnn的输入,static_rnn最终用于计算损失和train_op。当is_start为False时,我希望train_op对weights_h1没有影响。但是,每次更新后权重都会改变。任何帮助是极大的赞赏。
最佳答案
这应该工作:
def return_init_state():
init_state = cell.zero_state(batch,tf.float32)
return init_state
def return_hidden_1():
with tf.name_scope('hidden1'):
weights_h1 = tf.Variable(
tf.truncated_normal([T, cells_dim],
stddev=1.0 / np.sqrt(T)),
name='weights')
biases_h1 = tf.Variable(tf.zeros([cells_dim]),
name='biases')
hidden1 = tf.nn.relu(tf.matmul(score, weights_h1) + biases_h1)
return hidden1
init_state2 = tf.cond(is_start, lambda: return_hidden_1, lambda: return_init_state)
注意在
tf.cond
上下文中如何调用这些方法。因此,无论创建什么op,都将在tf.cond
的上下文内。否则,根据您的情况,操作将以两种方式运行。关于tensorflow - Tensorflow cond不会在假分支上停止渐变,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/45391008/