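I'm trying to understand how the CTC implementation works in TensorFlow. I've written a simple example to test the CTC function, but for some reason I get inf for some target/input values, and I'm not sure why that's happening!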
Code:
# TensorFlow 1.x API (tf.placeholder / tf.Session)
import tensorflow as tf
import numpy as np

# https://github.com/philipperemy/tensorflow-ctc-speech-recognition/blob/master/utils.py
def sparse_tuple_from(sequences, dtype=np.int32):
    """Create a sparse representation of x.
    Args:
        sequences: a list of lists of type dtype where each element is a sequence
    Returns:
        A tuple with (indices, values, shape)
    """
    indices = []
    values = []
    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)
    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)
    return indices, values, shape

batch_size = 1
seq_length = 2
n_labels = 2

seq_len = tf.placeholder(tf.int32, [None])
targets = tf.sparse_placeholder(tf.int32)
# Fixed random logits, so every iteration sees the same input
logits = tf.constant(np.random.random((batch_size, seq_length, n_labels + 1)), dtype=tf.float32)  # +1 for the blank label
loss = tf.reduce_mean(tf.nn.ctc_loss(targets, logits, seq_len, time_major=False))

with tf.Session() as sess:
    for it in range(10):
        # Random target of length seq_length, drawn from the n_labels real labels
        rand_target = np.random.randint(n_labels, size=(seq_length))
        sample_target = sparse_tuple_from([rand_target])
        logitsval = sess.run(logits)
        lossval = sess.run(loss, feed_dict={seq_len: [seq_length], targets: sample_target})
        print('******* Iter: %d *******' % it)
        print('logits:', logitsval)
        print('rand_target:', rand_target)
        print('rand_sparse_target:', sample_target)
        print('loss:', lossval)
        print()
Sample output:
******* Iter: 0 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521
******* Iter: 1 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [1 1]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([1, 1], dtype=int32), array([1, 2]))
loss: inf
******* Iter: 2 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521
******* Iter: 3 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [1 0]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([1, 0], dtype=int32), array([1, 2]))
loss: 1.59766
******* Iter: 4 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [0 0]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([0, 0], dtype=int32), array([1, 2]))
loss: inf
******* Iter: 5 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521
******* Iter: 6 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [1 0]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([1, 0], dtype=int32), array([1, 2]))
loss: 1.59766
******* Iter: 7 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [1 1]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([1, 1], dtype=int32), array([1, 2]))
loss: inf
******* Iter: 8 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521
******* Iter: 9 *******
logits: [[[ 0.10151503 0.88581538 0.56466645]
[ 0.76043415 0.52718711 0.01166286]]]
rand_target: [0 0]
rand_sparse_target: (array([[0, 0],
[0, 1]]), array([0, 0], dtype=int32), array([1, 2]))
loss: inf
Any idea what I'm missing there!?
Best answer
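Take a close look at your input texts (rand_target) and I'm sure you'll spot a simple pattern that correlates with the inf loss values ;-)
A short explanation of what is happening: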
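CTC encodes a text by allowing each character to be repeated, and it also allows a non-character marker (called the "CTC blank label") to be inserted between characters. Undoing this encoding (i.e. decoding) simply means collapsing repeated characters into one and then throwing away all blanks.
To illustrate ("..." denotes a text, '...' denotes an encoding, and '-' denotes the blank label):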
"to" -> 'tttooo' or 't-o' or 't-oo' or 'to', and so on...
"too" -> 'to-o' or 'tttoo---oo' or '---t-o-o--', but NOT 'too' (think about what 'too' decodes to)
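The decoding rule above (collapse repeats first, then drop blanks) is small enough to write out directly. A minimal sketch, assuming the encoding is a plain string and '-' is the blank as in the examples; `ctc_collapse` is a hypothetical helper name, not a TensorFlow function:

```python
BLANK = "-"


def ctc_collapse(encoding, blank=BLANK):
    """Decode a CTC encoding: merge runs of repeated symbols, then drop blanks."""
    decoded = []
    previous = None
    for symbol in encoding:
        # Keep a symbol only when it starts a new run (differs from the one before it).
        if symbol != previous:
            decoded.append(symbol)
        previous = symbol
    # Blanks only separate characters; they never appear in the decoded text.
    return "".join(s for s in decoded if s != blank)


for enc in ["tttooo", "t-o", "t-oo", "to", "to-o", "tttoo---oo", "---t-o-o--", "too"]:
    print(repr(enc), "->", repr(ctc_collapse(enc)))
```

Note that 'too' collapses to "to": the only way to keep both o's is to put a blank between them.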
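Now we know enough to see why some of the samples fail:
the length of your input text is 2
the length of the encoding is 2
if an input character is repeated (e.g. '11', or as a Python list: [1, 1]), the only way to encode it is to put a blank between the two (think about what '11' and '1-1' decode to). But then the encoding has length 3.
So there is no way to encode a length-2 text with a repeated character into a length-2 encoding, and therefore the TF loss implementation returns inf.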
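This can be checked numerically. The sketch below is a brute-force stand-in, not TensorFlow's implementation (TF uses the forward-backward dynamic program): it sums the probability of every length-T path that decodes to the target and returns -log of that sum. It assumes, as `tf.nn.ctc_loss` does, that the last class is the blank and that the logits are passed through a softmax. The logits are copied from the sample output; helper names are hypothetical, and the enumeration is exponential in T, so it only suits toy sizes like this one.

```python
import itertools
import math

import numpy as np


def softmax(x):
    # Numerically stable softmax over the class axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def collapse_path(path, blank):
    # Merge repeated labels, then drop blanks.
    out, prev = [], None
    for s in path:
        if s != prev and s != blank:
            out.append(s)
        prev = s
    return out


def brute_force_ctc_loss(logits, target):
    """-log p(target | logits), summing over every length-T path."""
    probs = softmax(np.asarray(logits, dtype=np.float64))
    T, num_classes = probs.shape
    blank = num_classes - 1  # assumption: last class is the blank, as in tf.nn.ctc_loss
    total = 0.0
    for path in itertools.product(range(num_classes), repeat=T):
        if collapse_path(path, blank) == list(target):
            total += math.prod(probs[t, k] for t, k in enumerate(path))
    # No valid path means probability 0, i.e. an infinite loss.
    return math.inf if total == 0.0 else -math.log(total)


# Logits of the single batch element from the sample output (T=2, 3 classes).
logits = [[0.10151503, 0.88581538, 0.56466645],
          [0.76043415, 0.52718711, 0.01166286]]
for target in ([0, 1], [1, 0], [1, 1], [0, 0]):
    print(target, brute_force_ctc_loss(logits, target))
```

For [0, 1] and [1, 0] this should land very close to the 2.61521 and 1.59766 printed above; for [1, 1] and [0, 0] no length-2 path exists, so the sum is 0 and the loss is inf.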
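You can also picture the encoding as a state machine. For the text "11" the states are, from left to right, '-', '1', '-', '1', '-', and the text is represented by all possible paths that start in a start state (the two leftmost states) and end in a final state (the two rightmost states). The middle blank cannot be skipped because the two 1s are identical, so the shortest path is '1-1'.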
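Enumerating the paths by length makes the "shortest path" claim concrete. A small sketch, assuming the alphabet is the two labels '0' and '1' plus '-' as the blank (mirroring n_labels = 2 in the question); `paths_for` is a hypothetical helper:

```python
import itertools


def collapse(path, blank="-"):
    # Merge repeated symbols, then drop blanks.
    out, prev = [], None
    for s in path:
        if s != prev and s != blank:
            out.append(s)
        prev = s
    return "".join(out)


def paths_for(text, length, alphabet="01-"):
    """All encodings of the given length over `alphabet` that decode to `text`."""
    return ["".join(p) for p in itertools.product(alphabet, repeat=length)
            if collapse(p) == text]


for T in range(2, 5):
    print(T, paths_for("11", T))
```

Length 2 yields no path at all, length 3 yields only '1-1', and longer encodings just add blanks or repeats around that skeleton.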
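In short: for every character in your target that repeats the character right before it, the input needs at least one extra time step for the blank that separates them.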
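That rule gives a cheap sanity check before calling the loss: the minimum number of time steps is the target length plus the number of adjacent equal pairs. A sketch with a hypothetical helper name, using seq_length = 2 from the question:

```python
def min_ctc_input_length(target):
    """Smallest number of time steps that can encode `target` under CTC."""
    # One step per label, plus one mandatory blank between each pair of equal neighbours.
    repeats = sum(1 for a, b in zip(target, target[1:]) if a == b)
    return len(target) + repeats


seq_length = 2
for target in ([0, 1], [1, 0], [1, 1], [0, 0]):
    needed = min_ctc_input_length(target)
    print(target, needed, "ok" if needed <= seq_length else "loss will be inf")
```

Whether to then lengthen the input, shorten the target, or skip such samples depends on your data.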
Maybe this article helps to understand CTC: https://towardsdatascience.com/3797e43a86c