多任务学习(Multi-task)keras实现

多目标任务存在很多场景中,如多目标检测,推荐系统中的多任务学习。

多任务学习(Multi-task learning)简介

多任务学习背景:只专注于单个模型可能会忽略一些相关任务中可能提升目标任务的潜在信息,通过进行一定程度的共享不同任务之间的参数,可能会使原任务泛化更好。广义的讲,只要loss有多个就算MTL,一些别名(joint learning,learning to learn,learning with auxiliary task)

多任务学习(Multitask learning)定义:基于共享表示(shared representation),把多个相关的任务放在一起学习的一种机器学习方法。

多任务学习(Multitask Learning)是一种推导迁移学习方法,主任务(main tasks)使用相关任务(related tasks)的训练信号(training signal)所拥有的领域相关信息(domain-specific information),做为一直推导偏差(inductive bias)来提升主任务(main tasks)泛化效果(generalization performance)的一种机器学习方法。

多任务学习目标:通过权衡主任务与辅助的相关任务中的训练信息来提升模型的泛化性与表现。从机器学习的视角来看,MTL可以看作一种inductive transfer(先验知识),通过提供inductive bias(某种对模型的先验假设)来提升模型效果。比如,使用L1正则,我们对模型的假设模型偏向于sparse solution(参数要少)。在MTL中,这种先验是通过auxiliary task来提供,更灵活,告诉模型偏向一些其他任务,最终导致模型会泛化得更好。

多任务学习(Multi-task learning)的两种模式

深度学习中两种多任务学习模式:隐层参数的硬共享与软共享

  • 隐层参数硬共享,指的是多个任务之间共享网络的同几层隐藏层,只不过在网络的靠近输出部分开始分叉去做不同的任务。
  • 隐层参数软共享,不同的任务使用不同的网络,但是不同任务的网络参数,采用距离(L1,L2)等作为约束,鼓励参数相似化。

而本次的代码实现采用的是隐层参数硬共享,也就是两个任务共享网络浅层的参数。

多任务学习(Multi-task)keras实现_第1张图片

                                                          上图是美团使用的多任务学习框架

在使用XGBoost进行单目标训练的时候,通过把点击的样本和下单的样本都作为正样本,并对下单的样本进行上采样或者加权,来平衡点击率和下单率。但这种样本的加权方式也会有一些缺点,例如调整下单权重或者采样率的成本较高,每次调整都需要重新训练,并且对于模型来说较难用同一套参数来表达这两种混合的样本分布。针对上述问题,可以利用DNN灵活的网络结构引入了Multi-task训练。

根据业务目标,我们把点击率和下单率拆分出来,形成两个独立的训练目标,分别建立各自的Loss Function,作为对模型训练的监督和指导。DNN网络的前几层作为共享层,点击任务和下单任务共享其表达,并在BP阶段根据两个任务算出的梯度共同进行参数更新。网络在最后一个全连接层进行拆分,单独学习对应Loss的参数,从而更好地专注于拟合各自Label的分布。

Multi-task DNN的网络结构如上图所示。线上预测时,我们将Click-output和Pay-output做一个线性融合。

# 搭建双任务并训练
def get_model():
    """函数式API搭建双塔DNN模型"""

    # 输入
    user_id = tf.keras.layers.Input(shape=(1,), name="user_id")
    store_id = tf.keras.layers.Input(shape=(1,), name="store_id")
    sku_id = tf.keras.layers.Input(shape=(1,), name="sku_id")
    search_keyword = tf.keras.layers.Input(shape=(1,), name="search_keyword")
    category_id = tf.keras.layers.Input(shape=(1,), name="category_id")
    brand_id = tf.keras.layers.Input(shape=(1,), name="brand_id")
    ware_type = tf.keras.layers.Input(shape=(1,), name="ware_type")

    # user特征
    user_vector = tf.keras.layers.concatenate([
        tf.keras.layers.Embedding(num_user_ids, 32)(user_id),
        tf.keras.layers.Embedding(num_store_ids, 8)(store_id),
        tf.keras.layers.Embedding(num_search_keywords, 16)(search_keyword)
    ])
    user_vector = tf.keras.layers.Dense(32, activation='relu')(user_vector)
    user_vector = tf.keras.layers.Dense(8, activation='relu',
                               name="user_embedding", kernel_regularizer='l2')(user_vector)

    # item特征
    movie_vector = tf.keras.layers.concatenate([
        tf.keras.layers.Embedding(num_sku_ids, 32)(sku_id),
        tf.keras.layers.Embedding(num_category_ids, 8)(category_id),
        tf.keras.layers.Embedding(num_brand_ids, 8)(brand_id),
        tf.keras.layers.Embedding(num_ware_types, 2)(ware_type)
    ])
    movie_vector = tf.keras.layers.Dense(32, activation='relu')(movie_vector)
    movie_vector = tf.keras.layers.Dense(8, activation='relu',
                                name="movie_embedding", kernel_regularizer='l2')(movie_vector)

    x = tf.keras.layers.concatenate([user_vector,movie_vector])
    out1 = tf.keras.layers.Dense(16,activation = 'relu')(x)
    out1 = tf.keras.layers.Dense(8,activation = 'relu')(out1)
    out1 = tf.keras.layers.Dense(1, activation='sigmoid',name = 'out1')(out1)
    
    out2 = tf.keras.layers.Dense(16,activation = 'relu')(x)
    out2 = tf.keras.layers.Dense(8,activation = 'relu')(out2)
    out2 = tf.keras.layers.Dense(1, activation='sigmoid',name = 'out2')(out2)

    return tf.keras.models.Model(inputs=[user_id, sku_id, store_id, search_keyword, category_id, brand_id,ware_type], 
                              outputs=[out1,out2])

模型代码部分

这里模型构建有两点需要注意:

  • 各个任务的输出层一定要命名,比如笔者这个模型的点击率输出层Dense(1, activation='sigmod',name = "out1")(out1)中的name ="out1",以及下单率输出层Dense(1, activation='sigmod',name = "out2")(out2)中的name ="out2"不能省略。

  • 第二个就是model.compile中的loss和loss的权重需要和任务输出层的name进行对应,如下:
    loss={'out1': loss,'out2': loss}
    loss_weights={'out1':1, 'crf_output': 1}

完整代码扫描下方二维码或微信搜索【有酒有风】关注回复【多任务】获取。

多任务学习(Multi-task)keras实现_第2张图片

 

你可能感兴趣的:(机器学习,深度学习,多任务学习,深度学习,机器学习,神经网络,人工智能)