Replacing LSTMBlockCell with LSTMBlockFusedCell raises a TypeError in static_rnn. I am using TensorFlow 1.2.0-rc1 compiled from source.
Full error message:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-3-2986e054cb6b> in <module>()
19 enc_cell = tf.contrib.rnn.LSTMBlockFusedCell(rnn_size)
20 enc_layers = tf.contrib.rnn.MultiRNNCell([enc_cell] * num_layers, state_is_tuple=True)
---> 21 _, enc_state = tf.contrib.rnn.static_rnn(enc_layers, enc_input_unstacked, dtype=dtype)
22
23 with tf.variable_scope('decoder'):
~/Virtualenvs/scikit/lib/python3.6/site-packages/tensorflow/python/ops/rnn.py in static_rnn(cell, inputs, initial_state, dtype, sequence_length, scope)
1139
1140 if not _like_rnncell(cell):
-> 1141 raise TypeError("cell must be an instance of RNNCell")
1142 if not nest.is_sequence(inputs):
1143 raise TypeError("inputs must be a sequence")
TypeError: cell must be an instance of RNNCell
Code to reproduce:
import tensorflow as tf
batch_size = 8
enc_input_length = 1000
dtype = tf.float32
rnn_size = 8
num_layers = 2
enc_input = tf.placeholder(dtype, shape=[batch_size, enc_input_length, 1])
enc_input_unstacked = tf.unstack(enc_input, axis=1)
with tf.variable_scope('encoder'):
    enc_cell = tf.contrib.rnn.LSTMBlockFusedCell(rnn_size)
    enc_layers = tf.contrib.rnn.MultiRNNCell([enc_cell] * num_layers)
    _, enc_state = tf.contrib.rnn.static_rnn(enc_layers, enc_input_unstacked, dtype=dtype)
_like_rnncell looks like this:
def _like_rnncell(cell):
    """Checks that a given object is an RNNCell by using duck typing."""
    conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
                  hasattr(cell, "zero_state"), callable(cell)]
    return all(conditions)
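The effect of this duck-typing check can be reproduced without TensorFlow. The sketch below copies the logic quoted above into a standalone function and runs it against two stand-in classes. `StepCellStandIn`, `FusedCellStandIn`, and `missing_attributes` are hypothetical names invented for illustration, not TensorFlow classes; the fused stand-in is only assumed to be callable while lacking the size attributes.

```python
def like_rnncell(cell):
    """Standalone copy of the duck-typing logic quoted above."""
    conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
                  hasattr(cell, "zero_state"), callable(cell)]
    return all(conditions)


class StepCellStandIn(object):
    """Hypothetical stand-in for a per-step cell (not a real TensorFlow class)."""
    output_size = 8
    state_size = (8, 8)

    def zero_state(self, batch_size, dtype):
        return None

    def __call__(self, inputs, state):
        return inputs, state


class FusedCellStandIn(object):
    """Hypothetical stand-in for a fused cell: callable, but no size attributes."""

    def __call__(self, inputs, dtype=None):
        return inputs, None


def missing_attributes(cell):
    """List which of the attributes required by the check are absent."""
    required = ("output_size", "state_size", "zero_state")
    return [name for name in required if not hasattr(cell, name)]


if __name__ == "__main__":
    print(like_rnncell(StepCellStandIn()))
    print(like_rnncell(FusedCellStandIn()))
    print(missing_attributes(FusedCellStandIn()))
```

Being callable is not enough to pass: any object missing one of the three attributes is rejected, and static_rnn then raises the TypeError shown in the traceback.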
It turns out that LSTMBlockFusedCell does not have the output_size and state_size properties that LSTMBlockCell implements. Is this a bug, or is there a way to use LSTMBlockFusedCell that I am missing?

Best answer
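LSTMBlockFusedCell inherits from FusedRNNCell rather than RNNCell, so it cannot be passed to the standard tf.nn.static_rnn or tf.nn.dynamic_rnn, both of which require an RNNCell instance (as the error message shows).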
However, as shown in the documentation, you can call the cell directly to get the full outputs and the final state:
inputs = tf.placeholder(tf.float32, [time_len, batch_size, input_size])
fused_rnn_cell = tf.contrib.rnn.LSTMBlockFusedCell(num_units)
outputs, state = fused_rnn_cell(inputs, dtype=tf.float32)
# outputs shape is (time_len, batch_size, num_units)
# state: LSTMStateTuple where c shape is (batch_size, num_units)
# and h shape is also (batch_size, num_units).
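Internally, the LSTMBlockFusedCell object calls gen_lstm_ops.block_lstm, which should be equivalent to a normal LSTM loop. Also note that the inputs to any FusedRNNCell instance must be time-major; you can get this by transposing the tensor before calling the cell.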
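The transposition mentioned above is just a swap of the first two axes. The sketch below illustrates it with NumPy so it can be run without TensorFlow, using the shapes from the question (batch_size = 8, enc_input_length = 1000, one feature). `to_time_major` is a hypothetical helper name; in the TensorFlow graph the same permutation would be written as tf.transpose(enc_input, [1, 0, 2]).

```python
import numpy as np


def to_time_major(batch_major):
    """Swap the batch and time axes: [batch, time, features] -> [time, batch, features]."""
    return np.transpose(batch_major, (1, 0, 2))


if __name__ == "__main__":
    batch_size, time_len, input_size = 8, 1000, 1
    # Dummy batch-major data with the same layout as enc_input in the question.
    enc_input = np.arange(batch_size * time_len * input_size, dtype=np.float32)
    enc_input = enc_input.reshape(batch_size, time_len, input_size)

    time_major = to_time_major(enc_input)
    print(time_major.shape)
```

With the input in this layout, the fused cell consumes the whole sequence in one call, so the tf.unstack step from the question is not needed.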