I want to train a neural network that generates text character by character. After some research, I decided to use an LSTM network for this task.

My input is structured as follows:
I have a file full of text (roughly 90,000,000 characters) which I cut into overlapping sequences of 50 characters. Consider this example sentence:

The quick brown fox jumps over the lazy dog

I split the text into sequences like this:

The quick brown

he quick brown_

e quick brown f

_quick brown fo

quick brown fox

I've added underscores where there are spaces that would otherwise not be visible.

These will be the time steps of my input data. The output is the next character after each sequence, so for the sequences above: _, f, o, x and _.
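For illustration, the windowing above could be sketched like this (a 15-character window to match the example sentence rather than the 50 used in the real setup; `make_sequences` is a made-up helper name, not part of my actual pipeline):

```python
# Slide a fixed-size window over the text one character at a time;
# each window is an input sequence, and the character that follows
# it is the training target.
def make_sequences(text, window):
    inputs, targets = [], []
    for i in range(len(text) - window):
        inputs.append(text[i:i + window])   # the time steps
        targets.append(text[i + window])    # the next character
    return inputs, targets

inputs, targets = make_sequences(
    "The quick brown fox jumps over the lazy dog", 15)
print(inputs[0])   # 'The quick brown'
print(targets[1])  # 'f'
```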

Each character is one-hot encoded as a vector whose length equals the number of characters in the dictionary, so if the alphabet consisted of A B C D, the character C would be represented as [0 0 1 0].
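A minimal sketch of that one-hot scheme with the toy A B C D dictionary (the helper name is illustrative):

```python
# Map each character to an index, then build a vector that is all
# zeros except for a 1 at that character's index.
vocab = "ABCD"  # toy dictionary from the example
char_to_idx = {c: i for i, c in enumerate(vocab)}

def one_hot(char):
    vec = [0] * len(vocab)
    vec[char_to_idx[char]] = 1
    return vec

print(one_hot("C"))  # [0, 0, 1, 0]
```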

Since I can't fit all of the vectorized text into memory at once, I split it into batches, each containing only a small number of the generated sequences for the network to train on.
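Such batching could be sketched as a generator that yields a few (input, target) pairs at a time, so the full one-hot tensor never has to exist in memory (names and the tiny batch size are illustrative; the real code reads the 90M-character file):

```python
# Yield batch_size overlapping windows at a time instead of
# materializing every sequence up front.
def batches(text, window, batch_size):
    batch = []
    for i in range(len(text) - window):
        batch.append((text[i:i + window], text[i + window]))
        if len(batch) == batch_size:
            yield batch
            batch = []

first = next(batches("The quick brown fox jumps over the lazy dog", 15, 4))
print(len(first))  # 4
```

Each yielded pair would then be one-hot encoded into a [batch_size, time_steps, char_size] tensor before being fed to the network.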

This gives me my input placeholder:

x = tf.placeholder(tf.float32, [batch_size, time_steps, char_size])


In the example code below, I use a batch_size of 128, time_steps of 50, and a char_size of 50 to represent a standard alphabet of uppercase and lowercase letters.

The num_units passed to BasicLSTMCell was also chosen somewhat arbitrarily as 256 (following a few tutorials for my hyperparameters).

Here is the code:

import tensorflow as tf

batch_size = 128
time_steps = 50
char_size = 50

num_units = 256

sess = tf.InteractiveSession()

X = tf.placeholder(tf.float32, [batch_size, time_steps, char_size])

cell = tf.contrib.rnn.BasicLSTMCell(num_units)
cell = tf.contrib.rnn.MultiRNNCell([cell] * 2, state_is_tuple=True)

output, state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)


And here is the error message:

Traceback (most recent call last):
  File ".\issue.py", line 16, in <module>
    output, state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\rnn.py", line 598, in dynamic_rnn
    dtype=dtype)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\rnn.py", line 761, in _dynamic_rnn_loop
    swap_memory=swap_memory)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2775, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2604, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2554, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\rnn.py", line 746, in _time_step
    (output, new_state) = call_cell()
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\rnn.py", line 732, in <lambda>
    call_cell = lambda: cell(input_t, state)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 180, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\layers\base.py", line 450, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 938, in call
    cur_inp, new_state = cell(cur_inp, cur_state)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 180, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\layers\base.py", line 450, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 401, in call
    concat = _linear([inputs, h], 4 * self._num_units, True)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 1039, in _linear
    initializer=kernel_initializer)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1065, in get_variable
    use_resource=use_resource, custom_getter=custom_getter)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 962, in get_variable
    use_resource=use_resource, custom_getter=custom_getter)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 360, in get_variable
    validate_shape=validate_shape, use_resource=use_resource)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 1405, in wrapped_custom_getter
    *args, **kwargs)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 183, in _rnn_get_variable
    variable = getter(*args, **kwargs)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 183, in _rnn_get_variable
    variable = getter(*args, **kwargs)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 352, in _true_getter
    use_resource=use_resource)
  File "C:\Users\uidq6096\AppData\Local\Programs\Python\Python35\lib\site-packages\tensorflow\python\ops\variable_scope.py", line 669, in _get_single_variable
    found_var.get_shape()))
ValueError: Trying to share variable rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel, but specified shape (512, 1024) and found shape (306, 1024).


I've been struggling with this for the past few days. What am I missing here?

Best Answer

Initialize the multiple cells in a loop instead of using the [cell] * n notation:

cells = []
for _ in range(n):
    cells.append(tf.contrib.rnn.BasicLSTMCell(num_units))  # build list of cells
cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)  # pass your list of cells
output, state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)


Otherwise it essentially tries to reuse the same cell instance for every layer, and the dimensions don't work out: the first layer's kernel is built for inputs of size char_size, while the second layer receives inputs of size num_units, so the shared variable's shapes conflict (hence the (512, 1024) vs (306, 1024) mismatch). I believe this behavior changed in version 1.0; you used to be able to get away with the [cell] * n syntax, but now you have to build the cells separately.

About python - ValueError when trying to build a network with BasicLSTMCell and dynamic_rnn in TensorFlow: we found a similar question on Stack Overflow: https://stackoverflow.com/questions/47205248/
