This article looks at how to retrieve a tensor of shape BATCH x DIM from an embedding matrix with TensorFlow's raw_rnn, which may be a useful reference if you are facing a similar problem.

Problem description

I am implementing an encoder-decoder LSTM, where I have to do custom computation at each step of the encoder, so I am using raw_rnn. However, I am having trouble accessing, at time step time, an element from the embeddings tensor, which is shaped Batch x Time steps x Embedding dimensionality.

Here is my setup:

import tensorflow as tf
import numpy as np

batch_size, max_time, input_embedding_size = 5, 10, 16
vocab_size, num_units = 50, 64

# Token ids and per-sequence lengths are fed in at run time.
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')

embeddings = tf.Variable(tf.random_uniform([vocab_size + 2, input_embedding_size], -1.0, 1.0),
                         dtype=tf.float32, name='embeddings')
# The lookup maps the id matrix to a batch x time x embedding tensor.
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

cell = tf.contrib.rnn.LSTMCell(num_units)
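For reference, the lookup above produces a rank-3 tensor whose leading two dimensions are dynamic, since both placeholder dimensions are None; a quick sanity check with the names just defined:

print(encoder_inputs_embedded.get_shape())  # (?, ?, 16) -- batch x time x embedding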

The main part:

with tf.variable_scope('ReaderNetwork'):
    def loop_fn_initial():
        init_elements_finished = (0 >= encoder_inputs_length)
        init_input = cell.zero_state(batch_size, tf.float32)
        init_cell_state = None
        init_cell_output = None
        init_loop_state = None
        return (init_elements_finished, init_input,
                init_cell_state, init_cell_output, init_loop_state)


    def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
        def get_next_input():
            # **TODO** read tensor of shape BATCH X EMBEDDING_DIM from encoder_inputs_embedded
            #  which has shape BATCH x TIME_STEPS x EMBEDDING_DIM
            pass  # placeholder so the snippet parses; this is the missing piece the question asks about

        elements_finished = (time >= encoder_inputs_length)
        finished = tf.reduce_all(elements_finished)  # boolean scalar
        input_val = tf.cond(finished,
                            true_fn=lambda: tf.zeros([batch_size, input_embedding_size]),
                            false_fn=get_next_input)
        state = previous_state
        output = previous_output
        loop_state = None
        return elements_finished, input_val, state, output, loop_state


    def loop_fn(time, previous_output, previous_state, previous_loop_state):
        if previous_state is None:  # time = 0
            assert previous_output is None and previous_state is None
            return loop_fn_initial()
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)
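For context: tf.nn.raw_rnn calls loop_fn once before the first step, with time = 0 and both previous_output and previous_state set to None, which is why loop_fn dispatches to loop_fn_initial; on every later call it receives the previous cell output and state and must return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state). Note also that loop_fn_initial here returns cell.zero_state(...) as the first input and None as the initial cell state; the accepted answer below swaps these, passing the zero state as the state and an embedded GO token as the first input.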

The running part:

reader_loop = loop_fn
encoder_outputs_ta, encoder_final_state, _ = tf.nn.raw_rnn(cell, loop_fn=reader_loop)
outputs = encoder_outputs_ta.stack()

def next_batch():
    # Note: encoder_inputs expects int32 token ids, but np.random.random yields
    # floats (the corrected version further below uses randint instead).
    return {
        encoder_inputs: np.random.random((batch_size, max_time)),
        encoder_inputs_length: [max_time] * batch_size
    }

init = tf.global_variables_initializer()
with tf.Session() as s:
    s.run(init)
    outs = s.run([outputs], feed_dict=next_batch())
    print(len(outs), outs[0].shape)

Question: how can I access part of the embeddings at a given time step and return a tensor of shape batch x embedding dim? See the function get_next_input within loop_fn_transition.

Thanks.

Solution

I was able to fix the problem. Since the embeddings have shape Batch x Time steps x Embedding dimensionality, I slice on the time dimension; the resulting tensor has shape (?, embedding dimensionality). It is also necessary to explicitly set the shape of the resulting tensor in order to avoid an error.

Here is the relevant part:

def get_next_input():
    # Slice out the current time step: (batch, time, dim) -> (batch, dim).
    embedded_value = encoder_inputs_embedded[:, time, :]
    # The slice has a dynamic batch dimension; pin it so raw_rnn accepts it.
    embedded_value.set_shape([batch_size, input_embedding_size])
    return embedded_value
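Why the set_shape call is needed: the placeholders are declared with None dimensions, so the slice comes out with a statically unknown batch dimension, and raw_rnn validates the static shape of the input returned by loop_fn. A minimal standalone sketch of the effect (an illustration, not part of the original post):

x = tf.placeholder(tf.float32, [None, None, 16])  # batch x time x dim
sliced = x[:, 3, :]                # pick out one time step
print(sliced.get_shape())          # (?, 16) -- the batch dimension is unknown
sliced.set_shape([5, 16])          # static assertion only; no data is changed
print(sliced.get_shape())          # (5, 16)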

Can anyone confirm whether this is the right way to solve the problem?

Here is the complete code for reference:

import tensorflow as tf
import numpy as np

batch_size, max_time, input_embedding_size = 5, 10, 16
vocab_size, num_units = 50, 64

encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')

embeddings = tf.Variable(tf.random_uniform([vocab_size + 2, input_embedding_size], -1.0, 1.0),
                         dtype=tf.float32, name='embeddings')
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

cell = tf.contrib.rnn.LSTMCell(num_units)
# Output projection parameters (defined but not used in this snippet).
W = tf.Variable(tf.random_uniform([num_units, vocab_size], -1, 1), dtype=tf.float32, name='W_reader')
b = tf.Variable(tf.zeros([vocab_size]), dtype=tf.float32, name='b_reader')
# A batch of GO-token ids (id 1), embedded as the first input to the RNN.
go_time_slice = tf.ones([batch_size], dtype=tf.int32, name='GO') * 1
go_step_embedded = tf.nn.embedding_lookup(embeddings, go_time_slice)


with tf.variable_scope('ReaderNetwork'):
    def loop_fn_initial():
        init_elements_finished = (0 >= encoder_inputs_length)
        init_input = go_step_embedded
        init_cell_state = cell.zero_state(batch_size, tf.float32)
        init_cell_output = None
        init_loop_state = None
        return (init_elements_finished, init_input,
                init_cell_state, init_cell_output, init_loop_state)

    def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
        def get_next_input():
            embedded_value = encoder_inputs_embedded[:, time, :]
            embedded_value.set_shape([batch_size, input_embedding_size])
            return embedded_value

        elements_finished = (time >= encoder_inputs_length)
        finished = tf.reduce_all(elements_finished)  # boolean scalar
        next_input = tf.cond(finished,
                             true_fn=lambda: tf.zeros([batch_size, input_embedding_size], dtype=tf.float32),
                             false_fn=get_next_input)
        state = previous_state
        output = previous_output
        loop_state = None
        return elements_finished, next_input, state, output, loop_state


    def loop_fn(time, previous_output, previous_state, previous_loop_state):
        if previous_state is None:  # time = 0
            return loop_fn_initial()
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)

reader_loop = loop_fn
encoder_outputs_ta, encoder_final_state, _ = tf.nn.raw_rnn(cell, loop_fn=reader_loop)
outputs = encoder_outputs_ta.stack()


def next_batch():
    # Random token ids in [0, vocab_size); every sequence is full length here.
    return {
        encoder_inputs: np.random.randint(0, vocab_size, (batch_size, max_time)),
        encoder_inputs_length: [max_time] * batch_size
    }


init = tf.global_variables_initializer()
with tf.Session() as s:
    s.run(init)
    outs = s.run([outputs], feed_dict=next_batch())
    print(len(outs), outs[0].shape)
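For reference, raw_rnn returns its outputs as a time-major TensorArray, so outputs.stack() should have shape (max_time, batch_size, num_units); the script is therefore expected to print 1 (10, 5, 64).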

This concludes the article on retrieving a tensor of shape BATCH x DIM from an embedding matrix with TensorFlow raw_rnn; I hope the answer above is helpful.
