我正在尝试使用LSTM单元和Tensorflow创建文本生成神经网络。我正在为网络训练时间主要格式的句子[time_steps,batch_size,input_size],我希望每个时间步都可以预测序列中的下一个单词。直到时间步长为止,序列中都填充有空值,并且一个单独的占位符包含批次中每个序列的长度。

关于通过时间进行反向传播的概念有很多信息,但是对于可变长度序列成本计算,我无法找到关于张量流中实际实现的任何信息。由于序列的末尾是填充的,因此我假设我不想计算填充部分的成本。因此,我需要一种将输出从第一个输出裁剪到序列末尾的方法。

这是我目前拥有的代码:

    outputs = []
    states = []
    cost = 0
    for i in range(time_steps+1):
        output, state = cell(X[i], state)
        z1 = tf.matmul(output, dec_W1) + dec_b1
        a1 = tf.nn.sigmoid(z1)
        z2 = tf.matmul(a1, dec_W2) + dec_b2
        a2 = tf.nn.softmax(z2)
        outputs.append(a2)
        states.append(state)
        #== calculate cost
        cost = cost + tf.nn.softmax_cross_entropy_with_logits(logits=z2, labels=y[i])
    optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)


此代码无需可变长度序列即可工作。但是,如果我在最后添加了填充值,那么它也会计算填充部分的成本,这没有多大意义。

如何仅计算序列长度上限之前的输出成本?

最佳答案

解决了!

在研究了许多示例(大多数示例是在Keras等较高级别的框架中)后,我发现您必须创建一个蒙版!回想起来似乎很简单。

这是创建1和0的掩码,然后将其与矩阵逐元素相乘的代码(这将是成本值)

x = tf.placeholder(tf.float32)
seq = tf.placeholder(tf.int32)

def mask_by_length(input_matrix, length):
    '''
        Input matrix is a 2d tensor [batch_size, time_steps]
        length is a 1d tensor
        length refers to the length of input matrix axis 1
    '''
    length_transposed = tf.expand_dims(length, 1)

    # Create range in order to compare length to
    range = tf.range(tf.shape(input_matrix)[1])
    range_row = tf.expand_dims(range, 0)

    # Use the logical operations to create a mask
    mask = tf.less(range_row, length_transposed)

    # cast boolean to int to finalize mask
    mask_result = tf.cast(mask, dtype=tf.float32)

    # Element-wise multiplication to cancel out values in the mask
    result = tf.multiply(mask_result, input_matrix)

    return result

mask_values = mask_by_length(x, seq)


输入值(主要时间)[time_steps,batch_size]

[[0.71,0.22,1.42,-0.28,0.99]
 [0.41、2.24、0.09、0.74、0.65]]

序列值[batch_size]

[2,3]

输出(主要时间)[time_steps,batch_size]

[[0.71,0.22,0,0,0,]
 [0.41,2.24,0.09,0,0,]]

关于python - 计算每个时间步长的可变长度输出的成本,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/48040685/

10-12 00:28
查看更多