我正在使用以下代码在keras中为NLP任务训练一个简单的模型。变量名称是火车,测试和验证集的自解释性。该数据集具有19个类别,因此网络的最后一层具有19个输出。标签也是一键编码的。

nb_classes = 19
model1 = Sequential()
model1.add(Embedding(nb_words,
                     EMBEDDING_DIM,
                     weights=[embedding_matrix],
                     input_length=MAX_SEQUENCE_LENGTH,
                     trainable=False))
model1.add(LSTM(num_lstm, dropout=rate_drop_lstm, recurrent_dropout=rate_drop_lstm))
model1.add(Dropout(rate_drop_dense))
model1.add(BatchNormalization())
model1.add(Dense(num_dense, activation=act))
model1.add(Dropout(rate_drop_dense))
model1.add(BatchNormalization())

model1.add(Dense(nb_classes, activation = 'sigmoid'))


model1.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
#One hot encode all labels
ytrain_enc = np_utils.to_categorical(train_labels)
yval_enc = np_utils.to_categorical(val_labels)
ytestenc = np_utils.to_categorical(test_labels)

model1.fit(train_data, ytrain_enc,
             validation_data=(val_data, yval_enc),
             epochs=200,
             batch_size=384,
             shuffle=True,
             verbose=1)

在第一个时期之后,这给了我这些输出。

Epoch 1/200
216632/216632 [==============================] - 2442s - loss: 0.1427 - acc: 0.9443 - val_loss: 0.0526 - val_acc: 0.9826

然后,我在测试数据集上评估我的模型,这也向我显示了0.98左右的准确性。

model1.evaluate(test_data, y = ytestenc, batch_size=384, verbose=1)

但是,标签是一键编码的,因此我需要类的预测 vector ,以便可以生成混淆矩阵等。因此,我使用

PREDICTED_CLASSES = model1.predict_classes(test_data, batch_size=384, verbose=1)
temp = sum(test_labels == PREDICTED_CLASSES)
temp/len(test_labels)
0.83

这表明总预测类别的准确度为83%,但是model1.evaluate显示的准确度为98%!我在这里做错了什么?我的损失函数可以与分类类标签一起使用吗?我为预测层选择的sigmoid激活功能可以吗?还是keras评估模型的方式不同?请提出可能出问题的建议。这是我第一次尝试制作更深的模型,因此我对这里的问题不太了解。

最佳答案

我发现了问题。 metrics=['accuracy']通过成本函数自动计算准确性。因此,使用binary_crossentropy显示的是二进制精度,而不是分类精度。使用categorical_crossentropy会自动切换到分类准确度,现在它与使用model1.predict()手动计算的准确度相同。于洋指出了多类问题的成本函数和激活函数是正确的。

附言:使用metrics=['binary_accuracy', 'categorical_accuracy']可以同时获得分类精度和二进制精度

关于machine-learning - Keras:模型。评估与模型。预测多类NLP任务中的准确性差异,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/45799474/

10-12 23:08