I'm trying to implement a custom loss function:

from keras import backend as K

def lossFunction(self, y_true, y_pred):

    # class index of the maximum in each row of y_true
    maxi = K.argmax(y_true)

    return K.mean((K.max(y_true) - K.gather(y_pred, maxi)) ** 2)


During training, I get the following error:




  InvalidArgumentError (see above for traceback): indices[5] = 51 is not in [0, 32)
       [[Node: loss/dense_3_loss/Gather = Gather[Tindices=DT_INT64, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](dense_3/BiasAdd, metrics/acc/ArgMax)]]




Model summary



_________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_1 (InputLayer)             (None, 64, 50, 1)     0
____________________________________________________________________________________________________
input_2 (InputLayer)             (None, 64, 50, 1)     0
____________________________________________________________________________________________________
conv2d_1 (Conv2D)                (None, 32, 25, 16)    272         input_1[0][0]
____________________________________________________________________________________________________
conv2d_2 (Conv2D)                (None, 32, 25, 16)    272         input_2[0][0]
____________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)   (None, 16, 12, 16)    0           conv2d_1[0][0]
____________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)   (None, 16, 12, 16)    0           conv2d_2[0][0]
____________________________________________________________________________________________________
conv2d_3 (Conv2D)                (None, 15, 11, 32)    2080        max_pooling2d_1[0][0]
____________________________________________________________________________________________________
conv2d_4 (Conv2D)                (None, 15, 11, 32)    2080        max_pooling2d_2[0][0]
____________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)   (None, 8, 6, 32)      0           conv2d_3[0][0]
____________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)   (None, 8, 6, 32)      0           conv2d_4[0][0]
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 1536)          0           max_pooling2d_3[0][0]
____________________________________________________________________________________________________
flatten_2 (Flatten)              (None, 1536)          0           max_pooling2d_4[0][0]
____________________________________________________________________________________________________
concatenate_1 (Concatenate)      (None, 3072)          0           flatten_1[0][0]
                                                                   flatten_2[0][0]
____________________________________________________________________________________________________
input_3 (InputLayer)             (None, 256)           0
____________________________________________________________________________________________________
concatenate_2 (Concatenate)      (None, 3328)          0           concatenate_1[0][0]
                                                                   input_3[0][0]
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 512)           1704448     concatenate_2[0][0]
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 256)           131328      dense_1[0][0]
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 256)           65792       dense_2[0][0]
====================================================================================================
Total params: 1,906,272
Trainable params: 1,906,272
Non-trainable params: 0

Best answer

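K.argmax works along the last axis (the 256 classes), while K.gather works along the first axis (the batch, here 32 samples). The two axes do not have the same number of elements, so an index such as 51, which is a valid class, is out of range as a batch index. This error is to be expected.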
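The mismatch can be reproduced without Keras. This is a minimal NumPy sketch, assuming a batch of 32 (the bound in the error message) and the 256 outputs of dense_3; np.argmax(..., axis=-1) stands in for K.argmax, np.take(..., axis=0) stands in for K.gather, and the one-hot targets are hypothetical.

```python
import numpy as np

batch, num_classes = 32, 256

# Hypothetical one-hot targets: sample i belongs to class 8 * i (0, 8, ..., 248)
labels = np.arange(batch) * 8
y_true = np.eye(num_classes)[labels]        # shape (32, 256)
y_pred = np.zeros((batch, num_classes))     # placeholder predictions, shape (32, 256)

# Like K.argmax: one index per sample, taken along the LAST axis (classes),
# so the values can be anywhere in [0, 256)
maxi = np.argmax(y_true, axis=-1)
print(maxi.shape)  # (32,)

# Like K.gather: indexes the FIRST axis (batch), which only has 32 rows,
# so any class index >= 32 is out of range
try:
    np.take(y_pred, maxi, axis=0)
    gather_failed = False
except IndexError as err:
    gather_failed = True
    print("IndexError:", err)
```

TensorFlow reports the same situation as an InvalidArgumentError instead of an IndexError.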

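To work on the classes we need the last axis, but gather only indexes the first one, so we do a somewhat hacky workaround around the gather method: transpose y_pred first.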

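from keras import backend as K

def lossFunction(self, y_true, y_pred):

    maxi = K.argmax(y_true)  # ok: one class index per sample, along the last axis

    # invert the axes: y_pred becomes (classes, batch), so gather now indexes classes
    y_pred = K.permute_dimensions(y_pred, (1, 0))

    # K.gather(y_pred, maxi) has shape (batch, batch); element [i, j] is the
    # original y_pred[j, maxi[i]]
    return K.mean((K.max(y_true, axis=-1) - K.gather(y_pred, maxi)) ** 2)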
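The shapes in the transposed version can be checked with a small NumPy sketch, where np.take stands in for K.gather and the sizes and labels are hypothetical. After the transpose, the gather returns a (batch, batch) matrix whose element [i, j] is y_pred[j, maxi[i]]. The per-sample prediction at each true class sits on its diagonal, so the mean above also averages over cross-sample pairs. If only the per-sample term is wanted, one option (a swapped-in technique, not part of the answer) is a one-hot mask. In the Keras backend that would be roughly K.sum(y_pred * K.one_hot(maxi, 256), axis=-1); treat that Keras line as an untested assumption.

```python
import numpy as np

batch, num_classes = 4, 6          # small hypothetical sizes, easy to inspect
rng = np.random.default_rng(1)
y_pred = rng.random((batch, num_classes))
labels = np.array([5, 0, 3, 3])    # hypothetical target classes
y_true = np.eye(num_classes)[labels]  # one-hot targets

maxi = np.argmax(y_true, axis=-1)     # like K.argmax(y_true), shape (batch,)
max_true = np.max(y_true, axis=-1)    # like K.max(y_true, axis=-1), shape (batch,)

# The answer's trick: transpose, then gather rows -> shape (batch, batch)
gathered = np.take(y_pred.T, maxi, axis=0)   # gathered[i, j] == y_pred[j, maxi[i]]
answer_loss = np.mean((max_true - gathered) ** 2)  # broadcasts over all (i, j) pairs

# Per-sample value only: y_pred[i, maxi[i]], i.e. the diagonal of `gathered`
picked = np.take_along_axis(y_pred, maxi[:, None], axis=-1)[:, 0]

# Same values via a one-hot mask (the idea behind the K.one_hot variant)
picked_mask = np.sum(y_pred * np.eye(num_classes)[maxi], axis=-1)

per_sample_loss = np.mean((max_true - picked) ** 2)
```

The mask form avoids gather entirely, so no axis has to be moved; it costs one elementwise multiply over the full (batch, classes) tensor.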

07-25 20:24