Problem description
I am implementing an encoder-decoder LSTM where I have to do custom computation at each step of the encoder, so I am using raw_rnn. However, I am facing a problem accessing, at time step time, an element of the embeddings tensor, which is shaped Batch x Time steps x Embedding dimensionality.
Here is my setup:
import tensorflow as tf
import numpy as np
batch_size, max_time, input_embedding_size = 5, 10, 16
vocab_size, num_units = 50, 64
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')
embeddings = tf.Variable(tf.random_uniform([vocab_size + 2, input_embedding_size], -1.0, 1.0),
                         dtype=tf.float32, name='embeddings')
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)
cell = tf.contrib.rnn.LSTMCell(num_units)
The main part:
with tf.variable_scope('ReaderNetwork'):
    def loop_fn_initial():
        init_elements_finished = (0 >= encoder_inputs_length)
        init_input = cell.zero_state(batch_size, tf.float32)
        init_cell_state = None
        init_cell_output = None
        init_loop_state = None
        return (init_elements_finished, init_input,
                init_cell_state, init_cell_output, init_loop_state)

    def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
        def get_next_input():
            # **TODO** read tensor of shape BATCH x EMBEDDING_DIM from encoder_inputs_embedded,
            # which has shape BATCH x TIME_STEPS x EMBEDDING_DIM
            pass

        elements_finished = (time >= encoder_inputs_length)
        finished = tf.reduce_all(elements_finished)  # boolean scalar
        input_val = tf.cond(finished,
                            true_fn=lambda: tf.zeros([batch_size, input_embedding_size]),
                            false_fn=get_next_input)
        state = previous_state
        output = previous_output
        loop_state = None
        return elements_finished, input_val, state, output, loop_state

    def loop_fn(time, previous_output, previous_state, previous_loop_state):
        if previous_state is None:  # time == 0
            assert previous_output is None and previous_state is None
            return loop_fn_initial()
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)
The running part:
reader_loop = loop_fn
encoder_outputs_ta, encoder_final_state, _ = tf.nn.raw_rnn(cell, loop_fn=reader_loop)
outputs = encoder_outputs_ta.stack()

def next_batch():
    return {
        # token ids must be integers in [0, vocab_size) to match the int32 placeholder
        encoder_inputs: np.random.randint(0, vocab_size, (batch_size, max_time)),
        encoder_inputs_length: [max_time] * batch_size
    }

init = tf.global_variables_initializer()
with tf.Session() as s:
    s.run(init)
    outs = s.run([outputs], feed_dict=next_batch())
    print(len(outs), outs[0].shape)
Question: how do I access part of the embeddings at a time step and return a tensor of shape batch x embedding dim? See the function get_next_input within loop_fn_transition.

Thanks.
Recommended answer
I was able to fix the problem. Since the embeddings have shape Batch x Time steps x Embedding dimensionality, I slice on the time dimension; the resulting tensor has shape (?, embedding dimensionality). It is also necessary to set the shape of the resulting tensor explicitly in order to avoid an error.
Here is the relevant part:
def get_next_input():
    embedded_value = encoder_inputs_embedded[:, time, :]
    embedded_value.set_shape([batch_size, input_embedding_size])
    return embedded_value
Can anyone confirm whether this is the right way to solve the problem?
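One alternative worth noting: the example in the tf.nn.raw_rnn documentation avoids slicing altogether by unstacking time-major inputs into a tf.TensorArray and reading one step per loop iteration, so no set_shape is needed. A minimal sketch of that pattern, adapted to the variables above (the time-major transpose is assumed here, since encoder_inputs_embedded is batch-major):

# Pattern from the tf.nn.raw_rnn docstring: unstack time-major inputs into
# a TensorArray and read a single (batch, embedding) slice per time step.
inputs_time_major = tf.transpose(encoder_inputs_embedded, [1, 0, 2])  # Time x Batch x Dim
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
inputs_ta = inputs_ta.unstack(inputs_time_major)

def get_next_input():
    # reads the embeddings for the current time step, shape (batch, embedding dim)
    return inputs_ta.read(time)

Inside loop_fn_transition this get_next_input closes over time exactly as before and can be passed to tf.cond unchanged.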
Here is the complete code for reference:
import tensorflow as tf
import numpy as np

batch_size, max_time, input_embedding_size = 5, 10, 16
vocab_size, num_units = 50, 64

encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')

embeddings = tf.Variable(tf.random_uniform([vocab_size + 2, input_embedding_size], -1.0, 1.0),
                         dtype=tf.float32, name='embeddings')
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

cell = tf.contrib.rnn.LSTMCell(num_units)

W = tf.Variable(tf.random_uniform([num_units, vocab_size], -1, 1), dtype=tf.float32, name='W_reader')
b = tf.Variable(tf.zeros([vocab_size]), dtype=tf.float32, name='b_reader')

# embedding of the GO token (id 1), fed as the input at time step 0
go_time_slice = tf.ones([batch_size], dtype=tf.int32, name='GO')
go_step_embedded = tf.nn.embedding_lookup(embeddings, go_time_slice)

with tf.variable_scope('ReaderNetwork'):
    def loop_fn_initial():
        init_elements_finished = (0 >= encoder_inputs_length)
        init_input = go_step_embedded
        init_cell_state = cell.zero_state(batch_size, tf.float32)
        init_cell_output = None
        init_loop_state = None
        return (init_elements_finished, init_input,
                init_cell_state, init_cell_output, init_loop_state)

    def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
        def get_next_input():
            # slice out the current time step and pin down the static shape
            embedded_value = encoder_inputs_embedded[:, time, :]
            embedded_value.set_shape([batch_size, input_embedding_size])
            return embedded_value

        elements_finished = (time >= encoder_inputs_length)
        finished = tf.reduce_all(elements_finished)  # boolean scalar
        next_input = tf.cond(finished,
                             true_fn=lambda: tf.zeros([batch_size, input_embedding_size], dtype=tf.float32),
                             false_fn=get_next_input)
        state = previous_state
        output = previous_output
        loop_state = None
        return elements_finished, next_input, state, output, loop_state

    def loop_fn(time, previous_output, previous_state, previous_loop_state):
        if previous_state is None:  # time == 0
            return loop_fn_initial()
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)

reader_loop = loop_fn
encoder_outputs_ta, encoder_final_state, _ = tf.nn.raw_rnn(cell, loop_fn=reader_loop)
outputs = encoder_outputs_ta.stack()

def next_batch():
    return {
        encoder_inputs: np.random.randint(0, vocab_size, (batch_size, max_time)),
        encoder_inputs_length: [max_time] * batch_size
    }

init = tf.global_variables_initializer()
with tf.Session() as s:
    s.run(init)
    outs = s.run([outputs], feed_dict=next_batch())
    print(len(outs), outs[0].shape)
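For reference, raw_rnn's emit TensorArray is time-major, so with these settings the final print should show something like 1 (10, 5, 64), i.e. (max_time, batch_size, num_units).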