我想在keras model.fit中使用class_weight参数来处理不平衡的训练数据。通过查看一些文档,我了解到我们可以通过这样的字典:

class_weight = {0 : 1,
    1: 1,
    2: 5}

(在此示例中,类别2将在损失函数中获得更高的惩罚。)

问题是我的网络输出具有一键编码,即class-0 =(1、0、0),class-1 =(0、1、0)和class-3 =(0、0、1)。

我们如何将class_weight用于一键编码输出?

通过查看some codes in Keras,看起来_feed_output_names包含一个输出类列表,但是在我的情况下,model.output_names/model._feed_output_names返回['dense_1']
相关:How to set class weights for imbalanced classes in Keras?

最佳答案

这是一个更短,更快的解决方案。如果您的热编码y是np.array:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y_integers = np.argmax(y, axis=1)
class_weights = compute_class_weight('balanced', np.unique(y_integers), y_integers)
d_class_weights = dict(enumerate(class_weights))

然后可以将d_class_weights传递给class_weight中的.fit

关于python - Keras:一键编码的类权重(class_weight),我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/43481490/

10-12 17:40
查看更多