Bert (Bi-directional Encoder Representations from Transformers) Pytorch 源码解读(三)

前言

Bert (Bi-directional Encoder Representations from Transformers) Pytorch 版本源码解读的第三篇,也是最后一部分。这一部分为源码中, wiki_dataset.py 文件中的内容,主要实现了 Bert 模型预训练时,数据的预处理工作。读完这一部分源码有助于更好的理解模型的输入部分的数据是如何构造的。


Bert 源码解读:

1. 模型结构源码: bert_model.py

2. 模型预训练源码:bert_training.py

3. 数据预处理源码:wiki_dataset.py


开始

1.初始化

class BERTDataset(Dataset):
    def __init__(self, corpus_path, word2idx_path, seq_len, hidden_dim=384, on_memory=True):
        # hidden dimension for positional encoding
        self.hidden_dim = hidden_dim
        # define path of dicts
        self.word2idx_path = word2idx_path
        # define max length
        self.seq_len = seq_len
        # load whole corpus at once or not
        self.on_memory = on_memory
        # directory of corpus dataset
        self.corpus_path = corpus_path
        # define special symbols
        self.pad_index = 0
        self.unk_index = 1
        self.cls_index = 2
        self.sep_index = 3
        self.mask_index = 4
        self.num_index = 5

        # 加载字典
        with open(word2idx_path, "r", encoding="utf-8") as f:
            self.word2idx = json.load(f)

        # 加载语料
        with open(corpus_path, "r", encoding="utf-8") as f:
            if not on_memory:
                # 如果不将数据集直接加载到内存, 则需先确定语料行数
                self.corpus_lines = 0
                for _ in tqdm.tqdm(f, desc="Loading Dataset"):
                    self.corpus_lines += 1

            if on_memory:
                # 将数据集全部加载到内存
                self.lines = [eval(line) for line in tqdm.tqdm(f, desc="Loading Dataset")]
                self.corpus_lines = len(self.lines)

        if not on_memory:
            # 如果不全部加载到内存, 首先打开语料
            self.file = open(corpus_path, "r", encoding="utf-8")
            # 然后再打开同样的语料, 用来抽取负样本
            self.random_file = open(corpus_path, "r", encoding="utf-8")
            # 下面是为了错位抽取负样本
            for _ in range(np.random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):
                self.random_file.__next__()

Bert 数据处理的所有源码都封装在这个 BERTDataset 类中,首先 __init__ 方法主要用于设置一些参数,加载训练语料和字典,以及加载方式。同时设置了一些特殊的 token,分别为:

pad_index : 填充位标识

unk_index : 未登录词标识

cls_index : 句首标识

sep_index : 句尾标识

mask_index : MASK 标识

num_index : 数字标识

2.Mask

    def tokenize_char(self, segments):
        return [self.word2idx.get(char, self.unk_index) for char in segments]

    def random_char(self, sentence):
        char_tokens_ = list(sentence)
        char_tokens = self.tokenize_char(char_tokens_)

        output_label = []
        for i, token in enumerate(char_tokens):
            prob = random.random()
            if prob < 0.15:
                prob /= 0.15
                output_label.append(char_tokens[i])
                # 80% randomly change token to mask token
                if prob < 0.8:
                    char_tokens[i] = self.mask_index
                # 10% randomly change token to random token
                elif prob < 0.9:
                    char_tokens[i] = random.randrange(len(self.word2idx))
            else:
                output_label.append(0)
        return char_tokens, output_label

tokenizer_char 方法将句子中的字,转化为对应的字典中的 token,并将未登录词标的 token 设置为 unk_index 。

random_char 实现了 Mask 的过程,首先将句子中的字转换为对应的 token,其次通过 random 产生的 prob 来筛选15% 的字进行 mask 的操作。没有被选到字会直接将 label 置为0,这一部分在计算 MLM(Mask Language Model) 任务的 loss 时会被直接忽略,不参与 loss 的计算。剩下 15% 被选中的这些字中,有 80% 的概率会被将其 token 设置为 MASK,有 10% 的概率将其 token 设置为随机另一个词,10% 的概率不改变其原本的 token,同时将这 15% 的词的原本 token 加入label 序列,用于预测以及计算 loss 值。

3.Next Sentence Prediction

    def random_sent(self, index):
        t1, t2 = self.get_corpus_line(index)

        # output_text, label(isNotNext:0, isNext:1)
        if random.random() > 0.5:
            return t1, t2, 1
        else:
            return t1, self.get_random_line(), 0

    def get_corpus_line(self, item):
        if self.on_memory:
            return self.lines[item]["text1"], self.lines[item]["text2"]
        else:
            line = self.file.__next__()
            if line is None:
                self.file.close()
                self.file = open(self.corpus_path, "r", encoding="utf-8")
                line = self.file.__next__()
            line = eval(line)
            t1, t2 = line["text1"], line["text2"]
            return t1, t2

    def get_random_line(self):
        if self.on_memory:
            return self.lines[random.randrange(len(self.lines))]["text2"]

        line = self.random_file.__next__()
        if line is None:
            self.random_file.close()
            self.random_file = open(self.corpus_path, "r", encoding="utf-8")
            for _ in range(np.random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)):
                self.random_file.__next__()
            line = self.random_file.__next__()
        return eval(line)["text2"]

这一部分源码主要实现 Next Sentence Prediction 任务所需要的数据处理,random_sent 方法中,首先从语料中获取上下句的搭配,再根据概率值判断,有 50% 的概率将下一句替换为随机的另一句话,有 50% 的概率不改变原来上下句的搭配。替换为随机的另一句话时,label 为0,原始搭配 label 为1。

get_corpus_line 方法实现从语料中抽取句子的过程,get_random_line 方法实现从语料中随机抽取一句话。

4. 组合数据

    def __getitem__(self, item):
        t1, t2, is_next_label = self.random_sent(item)

        t1_random, t1_label = self.random_char(t1)
        t2_random, t2_label = self.random_char(t2)

        t1 = [self.cls_index] + t1_random + [self.sep_index]
        t2 = t2_random + [self.sep_index]

        t1_label = [self.pad_index] + t1_label + [self.pad_index]
        t2_label = t2_label + [self.pad_index]

        segment_label = ([0 for _ in range(len(t1))] + [1 for _ in range(len(t2))])[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]

        output = {"bert_input": torch.tensor(bert_input),
                  "bert_label": torch.tensor(bert_label),
                  "segment_label": torch.tensor(segment_label),
                  "is_next": torch.tensor([is_next_label])}

        return output

__getitem__ 方法将之前的两部分数据处理的过程组合起来,首先进行句子层面的组合,再对组合起来的句子进行随机 MASK 的处理,然后生成 segment_label 与 bert_label ,最后将数据打包为处理好的最终数据格式。


总结

以上就是 Bert 数据预处理部分的全部代码,解析的顺序没有按照源码文件中的顺序进行,而是根据函数调用的逻辑来一一介绍,这样更有助于理解。

最后,Bert Pytorch 版本的源码解读也就全部结束,看过源码后确实可以更加深刻的了解模型结构以及训练过程。其中此版本的源码与 Google 官方开源的 TensorFlow 版本代码相比还是有一些细节上的差异,比如 Positional Encoding 的部分,但大体上还是与论文中的模型保持一致。

 

如有问题欢迎指正,转载请注明出处。  

你可能感兴趣的:(Python,NLP,Python,NLP,Bert,Pytorch)