我想在keras model.fit中使用class_weight参数来处理不平衡的训练数据。通过查看一些文档,我了解到我们可以通过这样的字典:
class_weight = {0 : 1,
1: 1,
2: 5}
(在此示例中,类别2将在损失函数中获得更高的惩罚。)
问题是我的网络输出具有一键编码,即class-0 =(1、0、0),class-1 =(0、1、0)和class-3 =(0、0、1)。
我们如何将class_weight用于一键编码输出?
通过查看some codes in Keras,看起来
_feed_output_names
包含一个输出类列表,但是在我的情况下,model.output_names
/model._feed_output_names
返回['dense_1']
相关:How to set class weights for imbalanced classes in Keras?
最佳答案
这是一个更短,更快的解决方案。如果您的热编码y是np.array:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
y_integers = np.argmax(y, axis=1)
class_weights = compute_class_weight('balanced', np.unique(y_integers), y_integers)
d_class_weights = dict(enumerate(class_weights))
然后可以将
d_class_weights
传递给class_weight
中的.fit
。关于python - Keras:一键编码的类权重(class_weight),我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/43481490/