Keras自定义损失函数,多个输入

Keras默认的自定义损失函数参数形式固定,一个为y_true,另一个为y_pred
例如:
 
def myloss(y_true,y_pred):
    pass
# 自定义损失函数
def myloss(y_true,y_pred,neighbor_Y):
    lamb1 = 0.7
    lamb2 = 0.3


    loss1 = K.mean(K.square(y_pred - y_true))
    neighbor_Y = Reshape((3,))(neighbor_Y)
    d1 = K.square(2 * neighbor_Y[:,1] - neighbor_Y[:,0] - neighbor_Y[:,2])
    s = neighbor_Y[:,0] + neighbor_Y[:,2]
    s = Reshape((1,))(s)
    d2 = K.square(2 * y_pred - s)
    d1 = Reshape((1,))(d1)
    d2 = Reshape((1,))(d2)
    loss2 = K.mean(K.abs(d1 - d2))
    loss = lamb1 * loss1 + lamb2 * loss2

    return loss

def dice_loss(neighbor_Y):
    def dice(y_true,y_pred):
        return myloss(y_true,y_pred,neighbor_Y)
        
    return dice

model

你可能感兴趣的:(深度学习)