利用bert4keras实现多任务学习

本文给出一个使用bert4keras实现多任务学习的例子

多任务学习

广义的讲,只要有多个loss就算Multi-task Learning。可以通过多个相关任务中的训练信息来提升模型的泛化性与表现

代码实现

其实多任务实现很简单,主要是在bert4keras框架的帮助下,如何去实现模型架构和设置输入输出。

我们以实现一个层级性多元标签文本分类(层级性多元标签是什么)。为例子

举个简单的例子,有一个电视产品,它属于“大家电”,也属于“家用电器”,而“大家电”标签是"家用电器"标签的子类,那么“家用电器”属于一级标签,“大家电” 属于二级标签,这产品所属种类标签是有层级结构。

我们的模型要实现两个任务 一个用来预测“家用电器”一级标签,一个用来预测“大家电”二级标签。

bert = build_transformer_model(config_path, checkpoint_path)

# 对文本的cls向量分别做两次多分类,得到两个输出level_1_output,level_2_output,分别计算loss。

level_1_cls = Lambda(lambda x: x[:, 0], name='level_1_CLS-token')(bert.output)
level_2_cls = Lambda(lambda x: x[:, 0], name='level_2_CLS-token')(bert.output)

level_1_output = Dense(len(level_1_category),
                       activation='softmax',
                       name='level_1_output')(level_1_cls)

level_2_output = Dense(len(level_2_category),
                       activation='softmax',
                       name='level_2_output')(level_2_cls)

model = Model(bert.inputs, [level_1_output, level_2_output])

losses = {
    "level_1_output": "categorical_crossentropy",
    "level_2_output": "categorical_crossentropy",
}
lossWeights = {"level_1_output": 1.0, "level_2_output": 1.0}

model.compile(
    loss=losses,
    optimizer=Adam(learning_rate),  # 用足够小的学习率
    loss_weights=lossWeights,
    metrics=['accuracy'],
)

对文本的cls向量分别做两次多分类,得到两个输出level_1_output,level_2_output,分别计算loss。

这里注意要对层进行命名,方便后面设置损失函数,调整损失权重。

如果多任务中的每一个的损失都相同,可以只写一个损失代替,不用每个都列出。

说完了模型结构,就要说数据的传入了

model = Model(bert.inputs, [level_1_output, level_2_output])

模型的输入是bert的输入,而输出的是两个label ,即一级标签的label和二级标签的label

class data_generator(DataGenerator):
    """数据生成器
    """

    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids, batch_labels, batch_2_labels, = [], [], [], []
        for is_end, (text_1, label_1, label_2) in self.sample(random):

            token_ids, segment_ids = tokenizer.encode(text_1, maxlen=maxlen)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)

            batch_labels.append(label_1)
            batch_2_labels.append(label_2)

            if len(batch_token_ids) == self.batch_size or is_end:

                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)

                batch_labels = sequence_padding(batch_labels)
                batch_2_labels = sequence_padding(batch_2_labels)

                yield [batch_token_ids,
                       batch_segment_ids], [batch_labels, batch_2_labels]

                batch_token_ids, batch_segment_ids, batch_labels, batch_2_labels = [], [], [], []

主要注意yield [batch_token_ids,
                       batch_segment_ids], [batch_labels, batch_2_labels]处的处理就可以,其余按照情况调整即可。

最后给出一个相关的例子,可以参考

hgliyuhao/mixup (github.com)

你可能感兴趣的:(bert4keras,python,自然语言处理,学习)