【记录】使用transformers从头开始训练bert

【记录】使用transformers从头开始训练bert

这篇记录主要记录使用transformers库训练从头开始训练自己的bert预训练模型;

bert训练任务;

bert预训练模型包含两个任务:

  1. mask词预测
  2. 相邻句子预测

使用的API

使用的api为BertForPreTraining

from transformers import BertConfig, BertForPreTraining
# 构建模型
config = BertConfig(vocab_size=len(WORDS) + 1)
model = BertForPreTraining(config)
# 训练:
for epoch in range(200):
    for data in data_loader:
  
        next_sentence_label = data['next_sentence_label'].to(device).long()
        input_ids = data['input_ids'].to(device).long()
        token_type_ids = data['token_type_ids'].to(device).long()
        attention_mask = data['attention_mask'].to(device).long()
        labels = data['bert_label'].to(device).long()
        
        outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
                        labels=labels, next_sentence_label=next_sentence_label)
        loss = outputs['loss']
        
  		optim.zero_grad()
        loss.backward()
        optim.step()

步骤

1)构建数据集,bert的数据集不用标注,因为他的训练任务标签可以程序生成,使用爬取的文章就行;

2)构建字典,一般来说是基于数据集来构建字典,但要注意bert中预留的几个特殊字符(用来对应任务,这里当然也可以自定义,但还是建议按照原版的来),;或者你也可以直接用下载好的字典;

3)构建数据集,和训练数据;这是重点,注意数据的组织形式;

首先是针对预训练任务,

1)next sentence, 你需要一次性输入两个句子,标签为句子是否相邻,注意输入的时候正负均衡采样;输入拼接如下:

input_token: 【cls】+sentence1+【sep】+sentence2+【sep】
next_sentence_label: 1&0

2)mask语言模型:bert针对上述句子,随机maks或者替换掉15%的词,这部分标签,则为mask位置的单词的真实编号,需要注意的是,标签的其他位置用-100代替,如下示例; 操作完之后记得padding;

def random_word(sentence):
    """mask language model, 添加15%的mask"""

    tokens = [char for char in sentence]

    output_label = []
    for i, token in enumerate(tokens):
        prob = random.random()
        # 0.15的概率进行替换,包括mask, 错误,和正常
        if prob < 0.15:
            prob /= 0.15
            # 80% , 每一个词有80%概率(0.15*0.8)会用mask替代
            if prob < 0.8:
                tokens[i] = voc['[MASK]']
            # 10%, 10%的概率随机取值
            elif prob < 0.9:
                tokens[i] = random.randrange(len(voc))
            # 10%, 10%的概率正常token
            else:
                tokens[i] = voc.get(token, voc['[UNK]'])
            #
            output_label.append(voc.get(token, voc['[UNK]']))
        # 正常取值
        else:
            tokens[i] = voc.get(token, voc['[UNK]'])
            output_label.append(-100)
    return tokens, output_label

3)构建token_type_id和attention_mask;token_typeid用来区分第一句还是第二句,attention_mask记录非padding部分;格式如下;

输入组合:【cls】你好吗【sep】我很好啊【sep】【pad】...
input_token: [102],[15],[17],[19],[103][14],[12],[17],[20],[103][0]...
token_type_id: 0, 0,0,0, 0, 1,1,1,1,1[0]...
attention_mask: 1,1,1,1,1,1,1,1,1,1,0,...

需要注意两点:

句子1包含了头部的【cls】和一个【sep】,句子2包含最后一个【sep】,所以token_type_id要对应到上,不能错了位置;bert_label不包含这个标识符,标识符位置用【pad】

参考代码如下:

 def __getitem__(self, idx):
        t1, t2, is_next_label = self.get_sentence(idx)
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        t1 = [self.vocab['[CLS]']] + t1_random + [self.vocab['[SEP]']]
        t2 = t2_random + [self.vocab['[SEP]']]
        t1_label = [self.vocab['[PAD]']] + t1_label + [self.vocab['[PAD]']]
        t2_label = t2_label + [self.vocab['[PAD]']]

        segment_label = ([0 for _ in range(len(t1))] + [1 for _ in range(len(t2))])[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]

        padding = [self.vocab['[PAD]'] for _ in range(self.seq_len - len(bert_input))]
        attention_mask = len(bert_input) * [1] + len(padding) * [0]
        bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)
        attention_mask = np.array(attention_mask)
        bert_input = np.array(bert_input)
        segment_label = np.array(segment_label)
        bert_label = np.array(bert_label)
        is_next_label = np.array(is_next_label)
        # 这里包含两个标签,一个是mask掉的词,另一个是上下文是否为相邻
        output = {"input_ids": bert_input,
                  "token_type_ids": segment_label,
                  'attention_mask': attention_mask,
                  "bert_label": bert_label}, is_next_label
        return output

你可能感兴趣的:(【记录】使用transformers从头开始训练bert)