Tensorflow: how to get intermediate cell states (c) from LSTMCell using dynamic_rnn?

Question
By default, the function dynamic_rnn outputs only the hidden states (known as m) for each time step, which can be obtained as follows:
cell = tf.contrib.rnn.LSTMCell(100)
rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                   inputs=inputs,
                                   sequence_length=sequence_lengths,
                                   dtype=tf.float32)
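For context, inputs and sequence_lengths are assumed to be defined elsewhere in the question; a minimal sketch of that setup might look like this (the shapes are illustrative assumptions, not from the original post):

import tensorflow as tf

# A padded batch of sequences: [batch_size, max_time, input_depth] (assumed shapes)
inputs = tf.placeholder(tf.float32, shape=[None, 20, 50])
# True (unpadded) length of each sequence in the batch
sequence_lengths = tf.placeholder(tf.int32, shape=[None])

Note that the second value returned by tf.nn.dynamic_rnn (discarded as _ above) is only the final LSTMStateTuple(c, m), which is why the intermediate cell states are not available here.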
Is there a way to get the intermediate (not final) cell states (c) in addition?
A tensorflow contributor mentions that it can be done with a cell wrapper:
class Wrapper(tf.nn.rnn_cell.RNNCell):
    def __init__(self, inner_cell):
        super(Wrapper, self).__init__()
        self._inner_cell = inner_cell

    @property
    def state_size(self):
        return self._inner_cell.state_size

    @property
    def output_size(self):
        return (self._inner_cell.state_size, self._inner_cell.output_size)

    def call(self, input, state):
        output, next_state = self._inner_cell(input, state)
        emit_output = (next_state, output)
        return emit_output, next_state
However, it doesn't seem to work. Any ideas?
Answer
The proposed solution works for me, but the Layer.call method spec is more general, so the following Wrapper should be more robust to API changes. Try this:
class Wrapper(tf.nn.rnn_cell.RNNCell):
    def __init__(self, inner_cell):
        super(Wrapper, self).__init__()
        self._inner_cell = inner_cell

    @property
    def state_size(self):
        return self._inner_cell.state_size

    @property
    def output_size(self):
        # Emit the state alongside the regular output at every step
        return (self._inner_cell.state_size, self._inner_cell.output_size)

    def call(self, input, *args, **kwargs):
        output, next_state = self._inner_cell(input, *args, **kwargs)
        emit_output = (next_state, output)
        return emit_output, next_state
Here's the test:
import numpy as np
import tensorflow as tf

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
basic_cell = Wrapper(tf.nn.rnn_cell.LSTMCell(num_units=n_neurons, state_is_tuple=False))
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
print(outputs, states)

X_batch = np.array([
    # t = 0      t = 1
    [[0, 1, 2], [9, 8, 7]], # instance 0
    [[3, 4, 5], [0, 0, 0]], # instance 1
    [[6, 7, 8], [6, 5, 4]], # instance 2
    [[9, 0, 1], [3, 2, 1]], # instance 3
])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    outputs_val = outputs[0].eval(feed_dict={X: X_batch})
    print(outputs_val)
The returned outputs is a tuple of (?, 2, 10) and (?, 2, 5) tensors, holding the per-step LSTM states and the per-step outputs respectively (with state_is_tuple=False the state is c and m concatenated, hence size 2 * n_neurons = 10). Note that I'm using the "graduated" version of LSTMCell from the tf.nn.rnn_cell package, not tf.contrib.rnn. Also note state_is_tuple=False to avoid dealing with LSTMStateTuple.
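If you specifically need the per-step cell states c, they can be sliced out of the emitted state tensor. A small follow-up sketch (not part of the original answer, reusing the variables from the test above): with state_is_tuple=False, LSTMCell concatenates c and m along the last axis, so:

# outputs[0]: (?, n_steps, 2 * n_neurons) -- per-step [c, m] concatenated
c_seq = outputs[0][:, :, :n_neurons]  # per-step cell states, shape (?, 2, 5)
m_seq = outputs[0][:, :, n_neurons:]  # per-step hidden states, shape (?, 2, 5)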