VGG16—perceptual loss in keras感知损失【Keras】

前言

正常的损失加上感知损失,肯定需要自定义合适的loss function。在keras中,自定义loss function :
先考虑keras中的loss,如下:

def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)

如果要定义自己的感知损失:

def model_loss(y_true, y_pred):
    inp = Input(shape=(128, 128, 1))
    x = Dense(2)(inp)
    x = Flatten()(x)

    model = Model(inputs=[inp], outputs=[x])
    a = model(y_pred)
    b = model(y_true)

    # calculate MSE
    mse = K.mean(K.square(a - b))
    return mse

上面的代码只是一个概念展示,但是能够指导我们应该如何去做。

举例

因为在做医疗影像,所以一般的医疗影像都不是正常的RGB三通道图像,往往是3D的影像或者是1通道影像,如下图所示:
VGG16—perceptual loss in keras感知损失【Keras】_第1张图片
脑部的失状图影像,图像大小是(182, 218,1),训练的神经网络输入是(160,200,1),那么如果要使用VGG16的感知损失的话,需要将其复制为3通道,具体细节代码如下:

def VGGloss(y_true, y_pred):  # Note the parameter order
    from keras.applications.vgg16 import VGG16
    mod = VGG16(include_top=False, weights='imagenet')
    pred = K.concatenate([y_pred, y_pred, y_pred])
    true = K.concatenate([y_true, y_true, y_true])
    vggmodel = mod 
    f_p = vggmodel(pred)  
    f_t = vggmodel(true)  
    return K.mean(K.square(f_p - f_t)) 

使用:

model.compile(optimizer=Adam, loss = losses.VGGloss) 

参考:https://stackoverflow.com/questions/43914931/vgg-perceptual-loss-in-keras

你可能感兴趣的:(深度学习笔记)