Keras Multi Loss

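When implementing the RCF network in Keras, the loss has to be computed on the outputs of several intermediate layers.

The post at https://blog.csdn.net/u012938704/article/details/79904173 describes one way to do this.

Only the key code is excerpted here: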

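from keras.applications.xception import Xception
from keras.layers import Input, Dense, Dropout, Lambda
from keras.models import Model
from keras.optimizers import SGD
from keras.utils import plot_model

# create the base pre-trained model
input_tensor = Input(shape=(299, 299, 3))
base_model = Xception(include_top=True, weights='imagenet', input_tensor=None, input_shape=None)
plot_model(base_model, to_file='xception_model.png')
# drop the final classification layer so the model outputs features instead
base_model.layers.pop()
base_model.outputs = [base_model.layers[-1].output]
base_model.layers[-1].outbound_nodes = []
base_model.output_layers = [base_model.layers[-1]]
feature = base_model
# apply the same (shared-weight) feature extractor to both input images
img1 = Input(shape=(299, 299, 3), name='img_1')
img2 = Input(shape=(299, 299, 3), name='img_2')
feature1 = feature(img1)
feature2 = feature(img2)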

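# Three loss functions
category_predict1 = Dense(100, activation='softmax', name='ctg_out_1')(
    Dropout(0.5)(feature1)
)
category_predict2 = Dense(100, activation='softmax', name='ctg_out_2')(
    Dropout(0.5)(feature2)
)
# eucl_dist is defined in the linked post and is not part of this excerpt
dis = Lambda(eucl_dist, name='square')([feature1, feature2])
# The excerpt omits the definition of judge. The next line is an assumed
# reconstruction, inferred from the 'bin_out' output name used in compile().
judge = Dense(2, activation='softmax', name='bin_out')(dis)
model = Model(inputs=[img1, img2], outputs=[category_predict1, category_predict2, judge])
# one loss per named output, combined using loss_weights
model.compile(optimizer=SGD(lr=0.0001, momentum=0.9),
              loss={
                  'ctg_out_1': 'categorical_crossentropy',
                  'ctg_out_2': 'categorical_crossentropy',
                  'bin_out': 'categorical_crossentropy'},
              loss_weights={
                  'ctg_out_1': 1.,
                  'ctg_out_2': 1.,
                  'bin_out': 0.5
              },
              metrics=['accuracy'])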
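For context, when `compile` is given one loss per named output together with `loss_weights`, the quantity that Keras minimizes is the weighted sum of the individual losses. The sketch below reproduces that arithmetic in plain Python. The targets and probabilities are made-up illustration values, not results from the model above, and `categorical_crossentropy` and `weighted_total_loss` are hand-written stand-ins rather than Keras functions.

```python
import math

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy for one sample: -sum(t * log(p))."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))

def weighted_total_loss(losses, weights):
    """Weighted sum over named outputs, mirroring `loss_weights`."""
    return sum(weights[name] * value for name, value in losses.items())

# Made-up one-hot targets and predicted probabilities (illustration only).
losses = {
    'ctg_out_1': categorical_crossentropy([0, 1, 0], [0.2, 0.7, 0.1]),
    'ctg_out_2': categorical_crossentropy([1, 0, 0], [0.6, 0.3, 0.1]),
    'bin_out': categorical_crossentropy([0, 1], [0.4, 0.6]),
}
loss_weights = {'ctg_out_1': 1.0, 'ctg_out_2': 1.0, 'bin_out': 0.5}

# total = 1.0 * ctg_out_1 + 1.0 * ctg_out_2 + 0.5 * bin_out
total = weighted_total_loss(losses, loss_weights)
print(round(total, 4))
```

With these weights, the `bin_out` term contributes half as much to the gradient as each classification term.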

In my tests, the approach from the linked post did not work.


In a GitHub issue I found the implementation below, which I verified does work.

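from keras.layers import Input, Dense, Activation
from keras.losses import categorical_crossentropy
from keras.models import Model

def prepare_model():
    inputs = Input(shape=(100, 100, 3))
    ...
    # previous_layer stands for the output of the layers elided above
    fc = Dense(100)(previous_layer)
    softmax = Activation('softmax')(fc)

    def custom_loss(y_true, y_pred):
        # softmax loss on the model output (y_pred is the softmax tensor)
        loss1 = categorical_crossentropy(y_true, y_pred)
        # center_loss is a helper from the issue and is not shown here;
        # fc is the intermediate tensor captured from the enclosing scope
        loss2 = center_loss(y_true, fc)
        # lambda_ is the center-loss weight (renamed from `lambda`, a Python
        # keyword); its value is not given in the source
        return loss1 + lambda_ * loss2

    model = Model(inputs=inputs, outputs=softmax)
    model.compile(optimizer='sgd',
                  loss=custom_loss,
                  metrics=['accuracy'])
    return model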
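The `center_loss` helper is not shown in the issue excerpt. Assuming it follows the usual center-loss definition (half the squared distance between a sample's feature vector and the center of its class), the combined objective `loss1 + lambda * loss2` can be sketched in plain Python as below. All names here, the fixed `centers` table, and the weight `lam = 0.5` are hypothetical choices for illustration. As in the code above, the `fc` output serves both as the logits for the softmax loss and as the feature for the center loss.

```python
import math

def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def softmax_loss(label, logits):
    """Cross-entropy of the softmax output against an integer class label."""
    return -math.log(softmax(logits)[label])

def center_loss(label, feature, centers):
    """0.5 * squared Euclidean distance to the center of the sample's class."""
    center = centers[label]
    return 0.5 * sum((f - c) ** 2 for f, c in zip(feature, center))

def combined_loss(labels, features, centers, lam):
    """Batch mean of the softmax loss plus lam times the batch mean center loss."""
    n = len(labels)
    loss1 = sum(softmax_loss(y, f) for y, f in zip(labels, features)) / n
    loss2 = sum(center_loss(y, f, centers) for y, f in zip(labels, features)) / n
    return loss1 + lam * loss2

# Two made-up samples with 2-dimensional fc outputs (illustration only).
labels = [0, 1]
features = [[2.0, 0.0], [0.0, 2.0]]
centers = {0: [1.0, 0.0], 1: [0.0, 1.0]}  # hypothetical fixed class centers
total = combined_loss(labels, features, centers, lam=0.5)
print(round(total, 4))
```

In the original center-loss formulation the class centers are themselves updated during training; the fixed `centers` table only keeps the sketch short.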
