为分类问题定义自定义损失函数时,是否可以访问y_true和y_pred的特定元素?
用例:多标签分类问题,如果我预测5类为假阳性,我想额外惩罚模型,即y_true[5] == 0
但y_pred[5] == 1
我将损失定义为:
def loss(y_true, y_pred):
wt = 10 if (y_true[5]==0 and y_pred[5]==1) else 1
return wt * binary_crossentropy(y_true, y_pred)
我还尝试检查
K.gather(y_true, 5) == 0
,但似乎没有。我的批处理大小> 1(256),并且我正在使用
fit_generator
-如果有任何区别。谢谢! 最佳答案
有没有办法访问y_true和y_pred的特定元素?
Keras张量的索引与numpy数组的索引相似。唯一的区别是结果是Keras张量。因此,您应该随后使用Keras操作。
损失函数的可能实现
例如,以下是损失函数的实现方式:
def loss(y_true, y_pred):
a = K.equal(y_true[:, 5], 0)
b = K.greater(y_pred[:, 5], 0.5)
condition = K.cast(a, 'float') * K.cast(b, 'float')
wt = 10 * condition + (1 - condition)
return K.mean(wt[:, None] * K.binary_crossentropy(y_true, y_pred), axis=-1)
注意:未经测试。
关于python - 访问自定义损失函数中的特定元素?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/57229998/