keras 实现多任务学习

def deep_multi_model(feature_dim, cvr_label_dim, profit_label_dim,
                     cvr_loss_weight=1, profit_loss_weight=0.3, dropout_rate=0.2):
    """Build, summarize and compile a two-headed multi-task MLP.

    A shared trunk of ReLU dense layers (512 -> 384 -> 256 -> dropout ->
    128 -> 64) feeds two task-specific heads:

    * ``output_cvr``: Dense(32, relu) -> Dense(cvr_label_dim, softmax)
    * ``output_profit``: Dense(16, relu) -> Dense(profit_label_dim, softmax)

    Both heads are trained with categorical cross-entropy, and the total loss
    is their weighted sum.

    Args:
        feature_dim: Number of input features per sample.
        cvr_label_dim: Number of classes predicted by the ``output_cvr`` head.
        profit_label_dim: Number of classes predicted by the ``output_profit``
            head.
        cvr_loss_weight: Weight of the ``output_cvr`` loss (default 1).
        profit_loss_weight: Weight of the ``output_profit`` loss (default 0.3).
        dropout_rate: Dropout rate applied after the 256-unit layer
            (default 0.2).

    Returns:
        The compiled Keras ``Model`` with outputs
        ``[output_cvr, output_profit]``. Its summary is printed as a side
        effect.
    """
    inputs = Input(shape=(feature_dim,))

    # Shared trunk. Layer creation order is kept fixed so Keras' auto-generated
    # layer names stay stable.
    dense_1 = Dense(512, activation='relu')(inputs)
    dense_2 = Dense(384, activation='relu')(dense_1)
    dense_3 = Dense(256, activation='relu')(dense_2)
    drop_1 = Dropout(dropout_rate)(dense_3)
    dense_4 = Dense(128, activation='relu')(drop_1)
    dense_5 = Dense(64, activation='relu')(dense_4)

    # CVR head.
    output_1 = Dense(32, activation='relu')(dense_5)
    output_cvr = Dense(cvr_label_dim, activation='softmax', name='output_cvr')(output_1)

    # Profit head.
    output_2 = Dense(16, activation='relu')(dense_5)
    output_profit = Dense(profit_label_dim, activation='softmax', name='output_profit')(output_2)

    # The model has one input and two outputs: output_cvr and output_profit.
    model = Model(inputs=inputs, outputs=[output_cvr, output_profit])
    model.summary()

    # One categorical_crossentropy loss per output. The keys of the loss and
    # loss_weights dicts must match the `name` of each output layer above.
    model.compile(
        optimizer='adam',
        loss={'output_cvr': 'categorical_crossentropy',
              'output_profit': 'categorical_crossentropy'},
        loss_weights={'output_cvr': cvr_loss_weight,
                      'output_profit': profit_loss_weight},
        metrics=[categorical_accuracy],
    )

    return model


def generate_arrays(X_train, y_train_cvr_label, y_train_profit_label, batch_size=1):
    """Endlessly yield ``(inputs, targets)`` training batches for ``fit_generator``.

    The model has a single input and two outputs, so every yielded item is::

        (x_batch, {'output_cvr': y_cvr_batch, 'output_profit': y_profit_batch})

    The dict keys match the names of the model's output layers. Rows are read
    from the three sequences in lockstep with ``zip``, so iteration stops at
    the shortest one. The data is then cycled from the start again, forever.

    Args:
        X_train: Sequence of feature rows (NumPy arrays).
        y_train_cvr_label: Sequence of CVR target rows (NumPy arrays).
        y_train_profit_label: Sequence of profit target rows (NumPy arrays).
        batch_size: Number of rows stacked into each yielded batch. The
            default of 1 yields one sample per step. The last batch of each
            pass may hold fewer rows.

    Raises:
        ValueError: If ``batch_size`` is less than 1, or if a full pass over
            the data produces no rows. This covers empty inputs and exhausted
            one-shot iterators, which would otherwise loop forever without
            yielding.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {!r}'.format(batch_size))

    def _make_batch(rows):
        # Stack each column of (x, y_cvr, y_profit) tuples along a new leading
        # batch axis; for a single row this equals row[np.newaxis, :].
        xs, cvrs, profits = zip(*rows)
        return (np.stack(xs),
                {'output_cvr': np.stack(cvrs), 'output_profit': np.stack(profits)})

    while True:
        rows = []
        produced_any = False
        for row in zip(X_train, y_train_cvr_label, y_train_profit_label):
            rows.append(row)
            if len(rows) == batch_size:
                produced_any = True
                yield _make_batch(rows)
                rows = []
        if rows:
            # Flush the final, possibly smaller, batch of this pass.
            produced_any = True
            yield _make_batch(rows)
        if not produced_any:
            # Without this guard, empty data makes `while True` spin forever.
            raise ValueError('generate_arrays: the training data is empty')


def train_multi(X_train, y_train_cvr_label, y_train_profit_label,
                X_test, y_test_cvr_label, y_test_profit_label,
                steps_per_epoch=1024, epochs=100, validation_steps=1024,
                patience=15):
    """Build the multi-task model and train it with ``fit_generator``.

    Input and label widths are read from ``shape[1]`` of the training arrays,
    so all training arrays must be 2-D.

    Args:
        X_train: Training features, 2-D.
        y_train_cvr_label: Training CVR targets, 2-D.
        y_train_profit_label: Training profit targets, 2-D.
        X_test: Validation features.
        y_test_cvr_label: Validation CVR targets.
        y_test_profit_label: Validation profit targets.
        steps_per_epoch: Generator steps per training epoch (default 1024).
        epochs: Maximum number of epochs (default 100).
        validation_steps: Generator steps per validation run (default 1024).
        patience: EarlyStopping patience on ``val_loss`` (default 15).

    Returns:
        The trained Keras model.
    """
    feature_dim = X_train.shape[1]
    cvr_label_dim = y_train_cvr_label.shape[1]
    profit_label_dim = y_train_profit_label.shape[1]

    # deep_multi_model() already prints model.summary(), so it is not
    # repeated here.
    model = deep_multi_model(feature_dim, cvr_label_dim, profit_label_dim)

    early_stopping = EarlyStopping(monitor='val_loss', patience=patience, verbose=0)

    # generate_arrays() cycles over its data forever, so steps_per_epoch and
    # validation_steps are what bound each epoch and each validation run.
    # NOTE(review): fit_generator is deprecated in TF2-era Keras, where
    # model.fit accepts generators directly. Switch if the installed version
    # allows it.
    model.fit_generator(generate_arrays(X_train, y_train_cvr_label, y_train_profit_label),
                        steps_per_epoch=steps_per_epoch,
                        epochs=epochs,
                        validation_data=generate_arrays(X_test, y_test_cvr_label, y_test_profit_label),
                        validation_steps=validation_steps,
                        callbacks=[early_stopping])

    return model


你可能感兴趣的:(算法,keras,深度学习,神经网络)