下表
来自 Professor Forcing: A New Algorithm for Training Recurrent Networks 论文。但我找不到他们计算 NLL 的代码。请问是不是单纯的二元交叉熵。我可以使用 Tensorflow tf.nn.sigmoid_cross_entropy_with_logits 函数计算它吗?
在Professional Forcing论文中,没有给出教师强制的评价结果。我训练了一个简单的 LSTM 并获得了 80.394 的 NLL。我的最后一个问题是获得 ~80 或 ~70 的可能性有多大?
更具体地说,我正在尝试逐像素生成 MNIST 图像。我的模型对每个像素进行二元预测,可以取值为 0 和 1。 logits 和标签的维度都是 [batch_size, 28*28, 1]
,其中 28 是 MNIST 图像的高度和宽度。
最佳答案
事实上,负对数似然是对数损失,或(二元)分类问题的(二元)交叉熵,但由于 MNIST 是一个多类问题,这里我们讨论分类交叉熵。它通常是首选,因为对数似然本身是负数,因此它的负数将是正数;来自 log_loss
的 scikit-learn 文档(强调):
不太确定如何使用 Tensorflow 做到这一点;这是使用 Keras 执行此操作的一种方法(为了保持代码简洁明了,我在 Keras MNIST CNN example 上构建,此处仅运行 2 个时期,因为我们只对获取 y_pred
和演示该过程感兴趣):
首先,这是 Keras 报告的测试集的分类交叉熵损失结果:
y_pred = model.predict(x_test)
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
# Test loss: 0.05165324027412571
# Test accuracy: 0.9834
现在让我们看看我们如何“手动”获得这个损失结果,以防我们有我们的预测
y_pred
和真实标签 y_test
而不管使用的任何特定模型;请注意,当我们的预测和真实标签都是单热编码时,该过程适用,即:y_pred[0]
# array([2.4637930e-07, 1.0927782e-07, 1.0026793e-06, 7.6613435e-07,
# 4.1209915e-09, 1.4566888e-08, 2.3195759e-10, 9.9999702e-01,
# 4.9344425e-08, 8.6051602e-07], dtype=float32)
y_test[0]
# array([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])
这是前奏:
from keras import backend as K
import numpy as np
y_test = y_test.astype('float32') # necessary here, since y_pred comes in this type - check in your case with y_test.dtype and y_pred.dtype
y_test = K.constant(y_test)
y_pred = K.constant(y_pred)
g = K.categorical_crossentropy(target=y_test, output=y_pred) # tensor
ce = K.eval(g) # 'ce' for cross-entropy
ce.shape
# (10000,) # i.e. one loss quantity per sample
# sum up and divide with the no. of samples:
log_loss = np.sum(ce)/ce.shape[0]
log_loss
# 0.05165323486328125
正如您可以直观地验证的那样,出于所有实际目的,这等于上面 Keras 本身报告的损失(
score[0]
);确实:np.isclose(log_loss, score[0])
# True
虽然不完全相等,可能是由于两种方法的数值精度差异:
log_loss == score[0]
# False
希望您现在应该能够使用上述过程在任何两个单热编码的
y_true
和 y_pred
集合之间获取对数损失(例如 MNIST)...关于python - 如何在 MNIST 数据集上计算负对数似然?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/52497625/