How to extract the cell state and hidden state from an RNN model in TensorFlow

Problem description

I am new to TensorFlow and have difficulty understanding the RNN module. I am trying to extract hidden/cell states from an LSTM. For my code, I am using the implementation from https://github.com/aymericdamien/TensorFlow-Examples.

import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell

# tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# Define weights
weights = {'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))}
biases = {'out': tf.Variable(tf.random_normal([n_classes]))}

def RNN(x, weights, biases):
    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

    # Permuting batch_size and n_steps
    x = tf.transpose(x, [1, 0, 2])
    # Reshaping to (n_steps*batch_size, n_input)
    x = tf.reshape(x, [-1, n_input])
    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.split(0, n_steps, x)

    # Define a lstm cell with tensorflow
    #with tf.variable_scope('RNN'):
    lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=True)

    # Get lstm cell output
    outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)

    # Linear activation, using rnn inner loop last output
    return tf.matmul(outputs[-1], weights['out']) + biases['out'], states

pred, states = RNN(x, weights, biases)

# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Initializing the variables
init = tf.initialize_all_variables()

Now I want to extract the cell/hidden state for each time step in a prediction. The state is stored in an LSTMStateTuple of the form (c, h), which I can see by running print states. However, calling print states.c.eval() (which, according to the documentation, should give me the values in the tensor states.c) yields an error stating that my variables are not initialized, even though I call it right after predicting something. The code for this is here:

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # List the variables in the 'RNN' scope
    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='RNN'):
        print v.name
    # Keep training until max iterations are reached
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})

        print states.c.eval()
        # Calculate batch accuracy
        acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})

        step += 1
    print "Optimization Finished!"

The error message is:

InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float
     [[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

The states are also not visible in tf.all_variables(), only the trained matrix/bias tensors (as described here: Tensorflow: show or save forget gate values in LSTM). I don't want to build the whole LSTM from scratch, though; since I already have the states in the states variable, I just need to evaluate it.

Recommended answer

You can simply collect the values of states in the same way accuracy is collected.

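I guess pred_val, states_val, acc = sess.run([pred, states, accuracy], feed_dict={x: batch_x, y: batch_y}) should work perfectly fine. Note that the fetches have to be passed as a single list, and that the results are assigned to new names so the graph tensors pred and states are not overwritten. The call states.c.eval() fails because it is made without a feed_dict, so the placeholder x has no value; passing the same feed_dict to eval() or to sess.run() avoids the error.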
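
One caveat for the "each time step" part of the question: the states value returned by rnn.rnn is only the final (c, h) pair, while the per-step hidden states are the entries of outputs. To make concrete what c and h look like at every step, here is a minimal NumPy sketch of a basic LSTM forward pass that records both. It is an illustration only, not TensorFlow's implementation: the function name lstm_forward, the gate order (i, j, f, o), the single weight matrix W, and the random inputs are all assumptions made for this example.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_forward(inputs, W, b, n_hidden, forget_bias=1.0):
    """Run a basic LSTM over inputs of shape (n_steps, batch, n_input).

    Hypothetical helper: returns (cs, hs), the cell and hidden states
    at every time step, each of shape (n_steps, batch, n_hidden).
    """
    n_steps, batch, _ = inputs.shape
    c = np.zeros((batch, n_hidden))  # initial cell state
    h = np.zeros((batch, n_hidden))  # initial hidden state
    cs, hs = [], []
    for t in range(n_steps):
        # One affine map produces all four gate pre-activations.
        gates = np.concatenate([inputs[t], h], axis=1) @ W + b
        i, j, f, o = np.split(gates, 4, axis=1)  # assumed gate order
        c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
        h = np.tanh(c) * sigmoid(o)
        cs.append(c)  # keep the state of this step
        hs.append(h)
    return np.stack(cs), np.stack(hs)


# Toy dimensions loosely mirroring the MNIST setup (28 steps of 28 inputs).
rng = np.random.default_rng(0)
n_steps, batch, n_input, n_hidden = 28, 4, 28, 8
W = rng.normal(scale=0.1, size=(n_input + n_hidden, 4 * n_hidden))
b = np.zeros(4 * n_hidden)
inputs = rng.normal(size=(n_steps, batch, n_input))

cs, hs = lstm_forward(inputs, W, b, n_hidden)
print(cs.shape, hs.shape)  # (28, 4, 8) (28, 4, 8)
```

In the TensorFlow graph above, hs corresponds to the list outputs and (cs[-1], hs[-1]) to the final states tuple; the intermediate cell states are not returned by rnn.rnn, so getting them would need a different approach, such as stepping the cell manually.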

09-03 10:06