本文主要介绍使用Prompt的MLM文本分类 bert4keras的代码实现,用以丰富bert4keras框架的例子 关于prompt的原理的文章网上有很多优秀的文章可以自行百度。
github地址
hgliyuhao/Prompt4Classification (github.com)
transformers,torch 实现的版本可以参考
基于Prompt的MLM文本分类_u013546508的博客-CSDN博客
剧烈运动后咯血,是怎么了? | 剧烈运动后为什么会咯血? | 1 |
剧烈运动后咯血,是怎么了? | 剧烈运动后咯血,应该怎么处理? | 0 |
每条数据是包含两句话 和一个标签 ,如果标签为1证明两句话意思相同 为0则证明意思不同
所以是一个分类任务
使用mlm模型,将分类任务转换成一个填空任务,再根据填空的结果决定分类结果。
根据任务和数据特点设计Prompt:两句话意思【mask】同。
所以对应模型的输入是 两句话意思【mask】同 :text1 ;text2
对应的label 如果标签为1 label 为‘相’,标签为0 label 为 ‘不’
因为在预测的时候是使用生成模型,为了确定mask结果的任务,所以Prompt要位于句子开头
由于需要使用mlm 要在引入模型的时候加入with_mlm=True
model = build_transformer_model(
config_path,
checkpoint_path,
with_mlm=True,
keep_tokens=keep_tokens, # 只保留keep_tokens中的字,精简原字表
)
下面重点说一下输入输出格式
class data_generator(DataGenerator):
def __iter__(self, random=False):
"""单条样本格式为
输入:[CLS]两句话意思[MASK]同,text1,text2[SEP]
输出:'相'或者'不'
"""
idxs = list(range(len(self.data)))
if random:
np.random.shuffle(idxs)
batch_token_ids, batch_segment_ids, batch_a_token_ids = [], [], []
for i in idxs:
data = self.data[i]
text = "两句话意思相同"
text1 = data[0]
text2 = data[1]
label = data[2]
final_text = text + ':' + text1 + ',' + text2
token_ids, segment_ids = tokenizer.encode(final_text, maxlen=maxlen)
# mask掉'相'字
token_ids[6] = tokenizer._token_mask_id
if label == 0:
a_token_ids, _ = tokenizer.encode('不')
else:
a_token_ids, _ = tokenizer.encode('相')
batch_token_ids.append(token_ids)
batch_segment_ids.append(segment_ids)
batch_a_token_ids.append(a_token_ids[1:])
if len(batch_token_ids) == self.batch_size or i == idxs[-1]:
batch_token_ids = sequence_padding(batch_token_ids)
batch_segment_ids = sequence_padding(batch_segment_ids)
batch_a_token_ids = sequence_padding(
batch_a_token_ids, 1
)
yield [batch_token_ids, batch_segment_ids], batch_a_token_ids
batch_token_ids, batch_segment_ids, batch_a_token_ids = [], [], []
token_ids[6] = tokenizer._token_mask_id 是将 "两句话意思相同" 转换成"两句话意思【mask】同"
batch_a_token_ids.append(a_token_ids[1:]) 这里(a_token_ids[1:] 是为了将cls 去掉
batch_a_token_ids = sequence_padding(batch_a_token_ids, 1) 这里是设置生成文本的长度,因为我们这个任务只需要预测结果为‘相’或者为‘不’所以 长度设置为1
句对模型 | prompt | |
acc | 0.93656 |
0.90414 |
prompt的效果是不如传统的句对模型的
但是prompt的思想是很有趣的,在few-shot 或者少样本的任务中是值得尝试的