本文介绍了如何制作一个自定义损失函数,该函数使用Keras中网络的先前输出?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试构建一个自定义损失函数,该函数将从网络中获取先前的输出(先前迭代的输出)并将其与当前输出一起使用.

I am trying to build a custom loss function that takes the previous output(output from the previous iteration) from the network and use it with the current output.

这是我想要做的,但我不知道如何完成

Here is what I am trying to do, but I don't know how to complete it

def l_loss(prev_output):

    def loss(y_true, y_pred):

        pix_loss = K.mean(K.square(y_pred - y_true), axis=-1)

        pase = K.variable(100)

        diff = K.mean(K.abs(prev_output - y_pred))
        movement_loss = K.abs(pase - diff)
        total_loss = pix_loss + movement_loss

        return total_loss
    return loss

self.model.compile(optimizer=Adam(0.001, beta_1=0.5, beta_2=0.9),
 loss=l_loss(?))

希望你能帮助我.

推荐答案

这是我尝试过的:

from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.models import Sequential
from tensorflow.keras import backend as K

class MovementLoss(object):
  def __init__(self):
    self.var = None

  def __call__(self, y_true, y_pred, sample_weight=None):
    mse = K.mean(K.square(y_true - y_pred), axis=-1)
    if self.var is None:
      z = np.zeros((32,))
      self.var = K.variable(z)
    delta = K.update(self.var, mse - self.var)
    return mse + delta


def make_model():
  model = Sequential()
  model.add(Dense(1, input_shape=(4,)))
  loss = MovementLoss()
  model.compile('adam', loss)
  return model

model = make_model()
model.summary()


使用示例测试数据.

import numpy as np

X = np.random.rand(32, 4)

POLY = [1.0, 2.0, 0.5, 3.0]
def test_fn(xi):
  return np.dot(xi, POLY)

Y = np.apply_along_axis(test_fn, 1, X)

history = model.fit(X, Y, epochs=4)

我确实看到损失函数在我看来受到最后一批增量的影响.请注意,损失函数的详细信息与您的应用程序不同.

I do see the loss function oscillate in a way that appears to me is influenced by the last batch delta. Note that the loss function details are not according to your application.

关键的一步是,K.update步骤必须是图形的一部分(据我所知).

The crucial step is that the K.update step must be part of the graph (as far as I understand it).

这是通过以下方式实现的:

That is achieved by:

delta = K.update(var, delta)
return x + delta

这篇关于如何制作一个自定义损失函数,该函数使用Keras中网络的先前输出?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

10-19 17:44