为分类问题定义自定义损失函数时,是否可以访问y_true和y_pred的特定元素?

用例:多标签分类问题,如果我预测5类为假阳性,我想额外惩罚模型,即y_true[5] == 0y_pred[5] == 1

我将损失定义为:

def loss(y_true, y_pred):
      wt = 10 if (y_true[5]==0 and y_pred[5]==1) else 1
      return wt * binary_crossentropy(y_true, y_pred)


我还尝试检查K.gather(y_true, 5) == 0,但似乎没有。

我的批处理大小> 1(256),并且我正在使用fit_generator-如果有任何区别。谢谢!

最佳答案

有没有办法访问y_true和y_pred的特定元素?


Keras张量的索引与numpy数组的索引相似。唯一的区别是结果是Keras张量。因此,您应该随后使用Keras操作。



损失函数的可能实现

例如,以下是损失函数的实现方式:

def loss(y_true, y_pred):
    a = K.equal(y_true[:, 5], 0)
    b = K.greater(y_pred[:, 5], 0.5)
    condition = K.cast(a, 'float') * K.cast(b, 'float')
    wt = 10 * condition + (1 - condition)
    return K.mean(wt[:, None] * K.binary_crossentropy(y_true, y_pred), axis=-1)


注意:未经测试。

关于python - 访问自定义损失函数中的特定元素?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/57229998/

10-12 21:32
查看更多