本文给出一个使用bert4keras实现多任务学习的例子
广义的讲,只要有多个loss就算Multi-task Learning。可以通过多个相关任务中的训练信息来提升模型的泛化性与表现
其实多任务实现很简单,主要是在bert4keras框架的帮助下,如何去实现模型架构和设置输入输出。
我们以实现一个层级性多元标签文本分类(层级性多元标签是什么)。为例子
举个简单的例子,有一个电视产品,它属于“大家电”,也属于“家用电器”,而“大家电”标签是"家用电器"标签的子类,那么“家用电器”属于一级标签,“大家电” 属于二级标签,这产品所属种类标签是有层级结构。
我们的模型要实现两个任务 一个用来预测“家用电器”一级标签,一个用来预测“大家电”二级标签。
bert = build_transformer_model(config_path, checkpoint_path)
# 对文本的cls向量分别做两次多分类,得到两个输出level_1_output,level_2_output,分别计算loss。
level_1_cls = Lambda(lambda x: x[:, 0], name='level_1_CLS-token')(bert.output)
level_2_cls = Lambda(lambda x: x[:, 0], name='level_2_CLS-token')(bert.output)
level_1_output = Dense(len(level_1_category),
activation='softmax',
name='level_1_output')(level_1_cls)
level_2_output = Dense(len(level_2_category),
activation='softmax',
name='level_2_output')(level_2_cls)
model = Model(bert.inputs, [level_1_output, level_2_output])
losses = {
"level_1_output": "categorical_crossentropy",
"level_2_output": "categorical_crossentropy",
}
lossWeights = {"level_1_output": 1.0, "level_2_output": 1.0}
model.compile(
loss=losses,
optimizer=Adam(learning_rate), # 用足够小的学习率
loss_weights=lossWeights,
metrics=['accuracy'],
)
对文本的cls向量分别做两次多分类,得到两个输出level_1_output,level_2_output,分别计算loss。
这里注意要对层进行命名,方便后面设置损失函数,调整损失权重。
如果多任务中的每一个的损失都相同,可以只写一个损失代替,不用每个都列出。
说完了模型结构,就要说数据的传入了
model = Model(bert.inputs, [level_1_output, level_2_output])
模型的输入是bert的输入,而输出的是两个label ,即一级标签的label和二级标签的label
class data_generator(DataGenerator):
"""数据生成器
"""
def __iter__(self, random=False):
batch_token_ids, batch_segment_ids, batch_labels, batch_2_labels, = [], [], [], []
for is_end, (text_1, label_1, label_2) in self.sample(random):
token_ids, segment_ids = tokenizer.encode(text_1, maxlen=maxlen)
batch_token_ids.append(token_ids)
batch_segment_ids.append(segment_ids)
batch_labels.append(label_1)
batch_2_labels.append(label_2)
if len(batch_token_ids) == self.batch_size or is_end:
batch_token_ids = sequence_padding(batch_token_ids)
batch_segment_ids = sequence_padding(batch_segment_ids)
batch_labels = sequence_padding(batch_labels)
batch_2_labels = sequence_padding(batch_2_labels)
yield [batch_token_ids,
batch_segment_ids], [batch_labels, batch_2_labels]
batch_token_ids, batch_segment_ids, batch_labels, batch_2_labels = [], [], [], []
主要注意yield [batch_token_ids,
batch_segment_ids], [batch_labels, batch_2_labels]处的处理就可以,其余按照情况调整即可。
最后给出一个相关的例子,可以参考
hgliyuhao/mixup (github.com)