I want to visualize attention scores in the latest version of TensorFlow (1.2). I build an RNNCell with AttentionWrapper from contrib.seq2seq, use it with a BasicDecoder as the decoder, and then generate the outputs step by step with dynamic_decode().

How can I get the attention weights for all of the steps? Thanks!

Best answer

You can access the attention weights by setting the alignment_history=True flag when constructing the AttentionWrapper.

Here is an example:

import tensorflow as tf

# Define attention mechanism
# (memory is the sequence of states to attend over, typically the encoder outputs)
attn_mech = tf.contrib.seq2seq.LuongMonotonicAttention(
    num_units=attention_unit_size, memory=decoder_outputs,
    memory_sequence_length=input_lengths)

# Define attention cell; alignment_history=True makes the wrapper keep the
# alignments of every decoding step in a TensorArray on its state
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell=decoder_cell, attention_mechanism=attn_mech,
    alignment_history=True)

# Define train helper
train_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=encoder_inputs,
    sequence_length=input_lengths)

# Define decoder
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=attn_cell,
    helper=train_helper, initial_state=decoder_initial_state)

# Dynamic decoding; dec_states is the final AttentionWrapperState, whose
# alignment_history field is a TensorArray of per-step alignments
dec_outputs, dec_states, _ = tf.contrib.seq2seq.dynamic_decode(decoder)

Then, in a session, you can fetch the attention weights like this:

with tf.Session() as sess:
    ...
    # Stack the TensorArray into a single tensor of shape
    # [output_steps, batch_size, input_steps]
    alignments = sess.run(dec_states.alignment_history.stack(), feed_dict)
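
The stacked tensor is time-major, i.e. [output_steps, batch_size, input_steps], which is why the plotting call below indexes it as alignments[:, i, :]. If you prefer one attention matrix per batch example, here is a minimal sketch using NumPy (the transpose is a convenience and not part of the original answer):

import numpy as np

# Reorder to [batch_size, output_steps, input_steps] so that
# alignments_bm[i] is the full attention map of example i
alignments_bm = np.transpose(alignments, (1, 0, 2))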

Finally, you can visualize the attention (alignments) like this:
import matplotlib.pyplot as plt

def plot_attention(attention_map, input_tags=None, output_tags=None):
    # attention_map: [output_len, input_len] matrix of attention weights
    out_len, in_len = attention_map.shape

    # Plot the attention_map
    plt.clf()
    f = plt.figure(figsize=(15, 10))
    ax = f.add_subplot(1, 1, 1)

    # Add image
    i = ax.imshow(attention_map, interpolation='nearest', cmap='Blues')

    # Add colorbar
    cbaxes = f.add_axes([0.2, 0, 0.6, 0.03])
    cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal')
    cbar.ax.set_xlabel('Alpha value (Probability output of the "softmax")', labelpad=2)

    # Add labels
    ax.set_yticks(range(out_len))
    if output_tags is not None:
        ax.set_yticklabels(output_tags[:out_len])

    ax.set_xticks(range(in_len))
    if input_tags is not None:
        ax.set_xticklabels(input_tags[:in_len], rotation=45)

    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Output Sequence')

    # add grid and legend
    ax.grid()

    plt.show()

# input_tags - word representation of input sequence, use None to skip
# output_tags - word representation of output sequence, use None to skip
# i - index of input element in batch

plot_attention(alignments[:, i, :], input_tags, output_tags)
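
If your batch is padded, you may also want to crop the map to the true sequence lengths before plotting, so the padded positions do not show up as empty rows and columns. A minimal sketch, assuming out_len_i and in_len_i hold the actual output and input lengths of example i (these names are not part of the answer above):

# Crop padded positions before plotting (out_len_i / in_len_i are assumed to
# come from your target lengths and from input_lengths, respectively)
plot_attention(alignments[:out_len_i, i, :in_len_i],
               input_tags[:in_len_i], output_tags[:out_len_i])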

(Figure: example attention heatmap produced by plot_attention, with the input sequence on the x-axis and the output sequence on the y-axis.)

This post, "tensorflow - How to visualize attention weights from AttentionWrapper", is based on a similar question on Stack Overflow: https://stackoverflow.com/questions/44613173/
