我正在使用TensorFlow v:1.1,并且我想使用tf.contrib.seq2seq api实现一个序列到序列模型。
但是,我很难理解如何使用提供的所有功能(BasicDecoder,Dynamic_decode,Helper,Training Helper ...)来构建模型。
这是我的设置:我想将一系列特征向量:(batch_size,encoder_max_seq_len,feature_dim)“翻译”成不同长度的序列(batch_size,decoder_max_len,1)。

我已经有一个编码器,它是带有LSTM单元的RNN,并且得到了它的最终状态,我希望将其作为初始输入馈送到解码器。
我已经有用于解码器MultiRNNCell LSM的单元。
您能帮我使用tf.contrib.seq2seq2和dynamic_decode的功能来构建最后一部分吗(示例代码或解释将不胜感激)?

这是我的代码:

import tensorflow as tf
from tensorflow.contrib import seq2seq
from tensorflow.contrib import rnn
import math

from data import gen_sum_2b2

class Seq2SeqModel:
def __init__(self,
             in_size,
             out_size,
             embed_size,
             n_symbols,
             cell_type,
             n_units,
             n_layers):
    self.in_size = in_size
    self.out_size = out_size
    self.embed_size = embed_size
    self.n_symbols = n_symbols
    self.cell_type = cell_type
    self.n_units = n_units
    self.n_layers = n_layers

    self.build_graph()

def build_graph(self):
    self.init_placeholders()
    self.init_cells()
    self.encoder()
    self.decoder_train()
    self.loss()
    self.training()

def init_placeholders(self):
    with tf.name_scope('Placeholders'):
        self.encoder_inputs = tf.placeholder(shape=(None, None, self.in_size),
                                             dtype=tf.float32, name='encoder_inputs')
        self.decoder_targets = tf.placeholder(shape=(None, None),
                                              dtype=tf.int32, name='decoder_targets')
        self.seqs_len = tf.placeholder(dtype=tf.int32)
        self.batch_size = tf.placeholder(tf.int32, name='dynamic_batch_size')
        self.max_len = tf.placeholder(tf.int32, name='dynamic_seq_len')
        decoder_inputs = tf.reshape(self.decoder_targets, shape=(self.batch_size,
                                    self.max_len, self.out_size))
        self.decoder_inputs = tf.cast(decoder_inputs, tf.float32)
        self.eos_step = tf.ones([self.batch_size, 1], dtype=tf.float32, name='EOS')
        self.pad_step = tf.zeros([self.batch_size, 1], dtype=tf.float32, name='PAD')

def RNNCell(self):
    c = self.cell_type(self.n_units, reuse=None)
    c = rnn.MultiRNNCell([self.cell_type(self.n_units) for i in range(self.n_layers)])
    return c

def init_cells(self):
    with tf.variable_scope('RNN_enc_cell'):
        self.encoder_cell = self.RNNCell()
    with tf.variable_scope('RNN_dec_cell'):
        self.decoder_cell = rnn.OutputProjectionWrapper(self.RNNCell(), self.n_symbols)

def encoder(self):
    with tf.variable_scope('Encoder'):
        self.init_state = self.encoder_cell.zero_state(self.batch_size, tf.float32)
        _, self.encoder_final_state = tf.nn.dynamic_rnn(self.encoder_cell, self.encoder_inputs,
                                                        initial_state=self.init_state)

最佳答案

解码层:

解码由两部分组成,因为它们在traininginference期间有所不同:


特定时间步长的解码器输入始终来自输出
上一个时间步。但是在训练过程中,输出是固定的
到实际目标(将实际目标作为输入反馈),这可以提高性能。


这两个都是使用tf.contrib.seq2seq中的方法处理的。


decoder的主要功能是:seq2seq.dynamic decoder()执行动态解码:

tf.contrib.seq2seq.dynamic_decode(decoder,maximum_iterations)

这需要一个Decoder实例和maximum_iterations=maximum seq length作为输入。

1.1 Decoder实例来自:

seq2seq.BasicDecoder(cell, helper, initial_state,output_layer)

输入为:cell(RNNCell实例),helper(帮助程序实例),initial_state(解码器的初始状态,应为编码器的输出状态)和output_layer(可选的密集层,如输出进行预测)

1.2 RNNCell实例可以是rnn.MultiRNNCell()

1.3 helper实例是与traininginference不同的实例。在training期间,我们希望将输入馈送到解码器,而在inference期间,我们希望将time-step (t)中的解码器输出作为输入传递到time step (t+1)中的解码器。

培训:我们使用辅助功能:
seq2seq.TrainingHelper(inputs, sequence_length),仅读取输入。

推断:我们调用辅助函数:
seq2seq.GreedyEmbeddingHelper() or seqseq.SampleEmbeddingHelper(),不同之处在于是否使用输出的argmax() or sampling(from a distribution)并将结果传递给嵌入层以获取下一个输入。


放在一起:Seq2Seq模型


encoder layer获取编码器状态,并将其作为initial_state传递给解码器。
使用decoder train获取decoder inferenceseq2seq.dynamic_decoder()的输出。当调用这两种方法时,请确保权重是共享的。 (使用variable_scope重用权重)
然后使用损失函数seq2seq.sequence_loss训练网络。


给出了示例代码herehere

关于tensorflow - 使用seq2seq API(1.1版及更高版本)的Tensorflow Sequence到序列模型,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/43622778/

10-12 21:55