我正在尝试理解Google's colab code。我应该如何使用此代码:

from keras import backend as K
prediction_model = lstm_model(seq_len=1, batch_size=BATCH_SIZE, stateful=True)
prediction_model.load_weights('/tmp/bard.h5')

get_test_layer_output = K.function([prediction_model.layers[0].input],
                                  [prediction_model.layers[1].output])
layer_output = get_test_layer_output([x])[0]


看每一层之后的值?还是有其他方法查看值(不是形状)?

Layer (type)                 Output Shape              Param #
=================================================================
seed (InputLayer)            (128, 100)                0
_________________________________________________________________
embedding (Embedding)        (128, 100, 512)           131072
_________________________________________________________________
lstm (LSTM)                  (128, 100, 512)           2099200
_________________________________________________________________
lstm_1 (LSTM)                (128, 100, 512)           2099200
_________________________________________________________________
time_distributed (TimeDistri (128, 100, 256)           131328
=================================================================
Total params: 4,460,800
Trainable params: 4,460,800
Non-trainable params: 0

最佳答案

对于要在Keras模型的各层上执行的任何操作,首先,我们需要访问模型保存的keras.layers对象的列表。

model_layers = model.layers


此列表中的每个Layer对象都有自己的inputoutput张量(如果您使用的是TensorFlow后端)

input_tensor = model.layers[ layer_index ].input
output_tensor = model.layers[ layer_index ].output


如果直接使用tf.Session.run()方法运行output_tensor,则会出现错误,指出在访问层的输出之前必须将输入馈入模型。

import tensorflow as tf
import numpy as np

layer_index = 3 # The index of the layer whose output needs to be fetched

model = tf.keras.models.load_model( 'model.h5' )
out_ten = model.layers[ layer_index ].output

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    output = sess.run(  out_ten , { model.input : np.ones((2,186))}  )
    print( output )


您需要在运行模型之前使用tf.global_variables_initializer().run()初始化变量。 model.input为模型的输入提供占位符张量。

关于python - 如何获得中间层的输出?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/56282323/

10-12 23:58