BERT代码逐行逐句详解版(pytorch版本)

最近入门BERT,在网上观看了一些网课视频理解了原理,并且找到了pytorch版本的源码,经过一遍阅读有了初步的认知,所以在此记录,温故而知新。

本文所解读的源码链接为:https://github.com/daiwk/BERT-pytorch/tree/master/bert_pytorch

其整体代码框架如下(有些部分我也略有改动,但整体不影响):

BERT代码逐行逐句详解版(pytorch版本)_第1张图片

 解读一个项目的代码,自然要从main开始,所以我们打开main.py(项目中是__main__.py)后看到首先是对一些路径参数的填写:

BERT代码逐行逐句详解版(pytorch版本)_第2张图片

 我个人的上述自个的参数为

--train_dataset ./corpus/train.tsv --test_dataset ./corpus/test.tsv --vocab_path ./vocab/vocab.txt --output_path output/bert.model

其中train_dataset和test_dataset是指你选的任务的训练数据和测试数据,我们一般称之为corpus(语料库)这里我们选取了GLUE数据集中的MRPC任务的训练集和测试集。而vocab_path指的是vocabulary库(词汇表库),它相当于一个大字典,记录了所有可能出现的单词,后边我们将语料库中的单词转为id时候需要在这个大字典里查找(这个vocab.txt可以去huggingface上找,GLUE数据集网上有许多好心人分享了网盘,但是需要注意的是不同的任务数据的样式是不同的,所以处理起来是有差别的!!比如说有的数据是一个类+一个句子+一个句子,有的是一个类+一个句子)

接下来我将以main函数中每行代码作为出发点,假设大家可以跟我一起跳转到对应文件(我会用灰色方框标注路径,大家可以跟着跳转)

main.py

 

这块是将你的单词表从你的txt格式转换成一个它对应的python对象以便它后边处理,点进去看:

dataset / vocab.py

BERT代码逐行逐句详解版(pytorch版本)_第3张图片

但是我在运行的时候出了一些问题,pickle在读的时候总是说非法字符。

main.py

BERT代码逐行逐句详解版(pytorch版本)_第4张图片于是我先用save_vocab再load_vocab成功了。然后是读入corpus的训练集和测试集,这里只看训练集:

 然后跳转进入BERTDataset (这种loader都是调用__getitem__方法实现的)

dataset / dataset.pyBERT代码逐行逐句详解版(pytorch版本)_第5张图片

这一部分完成的事情就是为NSP和MLM,刚开始我直接把MRPC的数据放进去发现和预想的不一样,后来通过细读这部分代码发现:这部分数据处理的代码默认了我们的语料库中每一行是有两个句子并且以 '\t' 分隔开。所以它的处理步骤:

(1)把一行中两个句子拿出来,然后以50%概率正常返回两个句子,并且label返回为1(代表这两个句子是连接在一起的,否则就随机选一个别的句子拼上去并label为0代表不是一起的,这一步可以为NSP任务处理)

BERT代码逐行逐句详解版(pytorch版本)_第6张图片

 BERT代码逐行逐句详解版(pytorch版本)_第7张图片

 (2)把这两个句子中的一个个词扣出来,按照论文中所讲的概率将单词变成[MASK]  其他词 原来的词的id(id就是指这个词在大字典中的序号,因为计算机读入时候想读入数字)。

 15%是变化过的,它的label要登记上这个词应该有的id作为目标,剩下没变的label直接变成0,因为想忽视这些单词集中注意搞那15%,在损失函数中会直接乘0消失BERT代码逐行逐句详解版(pytorch版本)_第8张图片

BERT代码逐行逐句详解版(pytorch版本)_第9张图片

(3)给两个句子的token加上头CLS和尾巴SEP,同时两者对应的label也是0

BERT代码逐行逐句详解版(pytorch版本)_第10张图片

 (4)分段id,以及补全的padding id

BERT代码逐行逐句详解版(pytorch版本)_第11张图片

上面是对数据的预处理部分,包括了对NSP和MLM任务分别的数据处理,以及加上CLS和SEP,还有对应的segment id等一些操作,下面正式开始构建BERT的整体框架。

main.py

 

model / bert.py

BERT代码逐行逐句详解版(pytorch版本)_第12张图片

 首先是BERT模型的初始化,我们知道BERT是Transformer的encoder部分,所以初始化这部分主要是对Transformer的结构的搭建(包括Transformer的block和多头的头数,hiddensize的大小,layer的层数,上述注释有提到)

然后看forward函数可以观察:

首先构建mask,找到数据中不为padding的地方给赋值为1并扩展成对应的维度

 

使用BERTEmbedding将单词编码

经过一个个Transformer块最后输出

上述BERTEmbedding 和 Transformer 介绍如下:

BERTEmbedding在 model / embedding / bertEmbedding.py

BERT代码逐行逐句详解版(pytorch版本)_第13张图片

 分别对三种embedding分别介绍引入(引用Bert细节整理 - 简书):

BERT代码逐行逐句详解版(pytorch版本)_第14张图片

Transformer块的代码在model / transformer.py / TransformerBlock

BERT代码逐行逐句详解版(pytorch版本)_第15张图片

 这里就是Transformer的encoder结构图中的四部分对应代码,可参考如图:

BERT代码逐行逐句详解版(pytorch版本)_第16张图片

最后就是常规的pytorch训练网络的训练过程,放进数据,并和Label计算Loss并反向传播

main.py

 

经过串下来后,我对开始的结构图标注了各自的作用,以便回顾

BERT代码逐行逐句详解版(pytorch版本)_第17张图片

忙活了一个下午的张师傅准备休息了~ 有错的话回头再改哇 

 

你可能感兴趣的:(正式开始炼丹,pytorch,bert,深度学习)