pytorch版本的BERT的源码链接
我将会为大家梳理代码,解读代码。并提出自己的一些见解
基本注释已经穿插在代码块中,另外一些看法单独拿出来说
#导包
import pickle
import tqdm
from collections import Counter
class TorchVocab(object):
def __init__(self,counter,max_sizeNone,min_freq=1,specials=['' ,'' ],
vectors=None,unk_init=None,vectors_cache=None):
"""
counter:
类型:collections.Counter对象
描述:用于存储数据中每个单词出现的频率。这个计数器对象通常是通过在数据集上进行词频统计得到的。
max_size:
类型:int或None
默认值:None
描述:词汇表的最大大小。如果设置为None,则不限制词汇表的大小。如果设置为一个具体的数字,词汇表的大小将限制为最多包含这么多单词。
min_freq:
类型:int
默认值:1
描述:包含在词汇表中的最小频率阈值。频率低于这个阈值的单词将不会被包括在词汇表中。小于1的值将被设置为1。
specials:
类型:字符串列表
默认值:['', '']
描述:一组特殊标记,它们将被添加到词汇表中。这些特殊标记通常包括用于填充()、表示未知单词(或)等特殊用途的标记。
vectors:
类型:预训练向量,可以是字符串列表
默认值:None
描述:用于指定预训练的词向量。这可以是预训练向量的名称,或者是用户自定义的预训练向量。
unk_init:
类型:回调函数
默认值:torch.Tensor.zero_
描述:用于初始化未知单词(OOV,out-of-vocabulary)向量的函数。默认情况下,未知单词的向量被初始化为零向量。
vectors_cache:
类型:字符串
默认值:'.vector_cache'
描述:用于缓存预训练向量的目录路径。如果提供了向量路径,则预训练向量将被下载或加载到这个目录中。
"""
self.freqs=counter
counter=counter.copy()
min_freq=max(min_freq,1)
self.itos=list(specials)
for tok in specials:
del counter[tok]
max_size=None if max_size is None else max_size+len(self.itos)
#按频率排序,然后按字母顺序排序
words_and_frequencies = sorted(counter.items(),key=lambda tup:tup[0])#按字母顺序排序
words_and_frequencies.sort(key=lambda tup:tup[1],reverse=True)#再按照频率从大到小排序
for word,freq in words_and_frequencies:
if freq<min_freq or len(self.itos)==max_size:
break
self.itos.append(word)
#stoi就是itos的翻转字典
self.stoi={tok:i for i,tok in enumerate(self.itos)}
self.vectors=None
if vectors is not None:
self.load_vectors(vectors,unk_init=unk_init,cache=vectors_cache)
else:
assert unk_init is None and vectors_cache is None
def __eq__(self,other):#判断两个字典是否相等
if self.freqs!=other.freqs:
return False
if self.stoi!=other.stoi:
return False
if self.itos!=other.itos:
return False
if self.vectors!=other.vectors:
return False
return True
def __len__(self):
return len(self.itos)
def vocab_rerank(self):#再次翻转为stoi
self.stoi={word: i for i,word in enumerate(self.itos)}
def extend(self,v,sort=False):#将另一个字典加入到当前字典中
words=sorted(v.itos) if sort else v.itos
for w in words:
if w not in self.stoi:
self.itos.append(w)
self.stoi[w]=len(self.itos)-1
class Vocab(TorchVocab):
def __init__(self,counter,max_size=None,min_freq=1):
self.pad_index=0
self.unk_index=1
self.eos_index=2
self.sos_index=3
self.mask_index=4
super().__init__(counter,specials=["" ,"" ,"" ,"" ,"" ],max_size=max_size,min_freq=min_freq)
def to_seq(self,sentence,seq_len,with_eos=False,with_sos=False)->list:#字符序列转换为数字序列
pass
def from_seq(self,seq,join=False,with_pad=False):#数字序列转换为字符序列
pass
@staticmethod
def load_vocab(vocab_path:str)->'Vocab':#静态方法,利用pickle将字符加载出来
with open(vocab_path,'rb') as f:
return pickle.load(f)
def save_vocab(self,vocab_path):#利用pickle保存数据
with open(vocab_path,'wb' ) as f:
pickle.dump(self,f)
class WordVocab(Vocab):
def __init__(self,texts,max_size=None,min_freq=1):
print("Building Vocab")
counter=Counter()
for line in tqdm.tqdm(texts):
if isinstance(line,list):
words=line
else:
words=line.replace("\n","").replace("\t","").split()
for word in words:
counter[word]+=1
super().__init__(counter,max_size=max_size,min_freq=min_freq)
def to_seq(self,sentence,seq_len=None,with_eos=False,with_sos=False,with_len=False):
if isinstance(sentence,str):
sentence=sentence.split()
if with_eos:#如果with_eos为真,意味着需要在序列末尾添加结束标记
seq+=[self.eos_index]
if with_sos:#如果with_sos为真,意味着需要在序列开头添加开始标记。
seq=[seq.sos_index]+seq
origin_seq_len=len(seq)#记录转换后序列的长度
if seq_len is None:
pass
elif len(seq)<=seq_len:
seq+=[self.pad_index]*(seq_len-len(seq))
else:
seq=seq[:seq_len]
return (seq,origin_seq_len) if with_len else seq #如果 with_len 为 True,函数将返回一个元组,其中origin_seq_len为不加其他标识符之前的句子
def from_seq(self,seq,join=False,with_pad=False):
"""
Args:
seq:一个整数列表,通常表示单词的索引序列。
join:一个布尔值,指示是否应该将转换后的单词列表拼接为一个字符串。
with_pad:一个布尔值,指示是否应该在输出中包含填充(PAD)标记。
"""
words=[self.itos[idx]
if idx<len(self.itos)
else "<%d>"% idx
for idx in seq if not with_pad or idx!=self.pad_index]#当with_pad为True时,所有的索引对应的单词都将加载到words中
return "".join(words) if join else words
@staticmethod
def load_vocab(vocab_path:str)->'WordVocab':#,用于从指定路径加载一个之前保存的WordVocab对象。
with open(vocab_path,"r") as f:
return pickle.load(f)
def build():
import argparse
parser = argparse.ArgumentParser()
"""
-c 或 --corpus_path:指定语料库文件的路径。
-o 或 --output_path:指定输出词汇表文件的路径。
-s 或 --vocab_size:指定词汇表的最大大小。
-e 或 --encoding:指定文件的编码格式,默认为utf-8。
-m 或 --min_freq:指定单词在语料库中出现的最小频率。
"""
parser.add_argument("-c", "--corpus_path", required=True, type=str)
parser.add_argument("-o", "--output_path", required=True, type=str)
parser.add_argument("-s", "--vocab_size", type=int, default=None)
parser.add_argument("-e", "--encoding", type=str, default="utf-8")
parser.add_argument("-m", "--min_freq", type=int, default=1)
args = parser.parse_args()
with open(args.corpus_path, "r", encoding=args.encoding) as f:
vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq) #读取语料库并构建词汇表
print("VOCAB SIZE:", len(vocab)) #打印词汇表的大小(即其中包含的单词数量)
vocab.save_vocab(args.output_path) #调用vocab.save_vocab 方法将词汇表保存到指定的输出文件路径
可以看到vocab.py包含三个类,这三个类从上往下依次继承
TorchVocab:
这个类是基础类,定义了词汇表的核心功能,如建立词汇表、索引到字符串(itos)和字符串到索引(stoi)的映射,处理预训练向量等。
它包含了初始化方法(init),用于从词频计数器创建词汇表,以及一些基本方法,如比较两个词汇表是否相等(eq)、获取词汇表长度(len)、重新排列词汇表(vocab_rerank)和扩展当前词汇表(extend)。
Vocab:
Vocab类继承自TorchVocab。它在TorchVocab的基础上增加了特定的特殊标记,如填充pad、未知单词unk、句末eos、句首sos和掩码mask。
Vocab类提供了额外的方法,如to_seq和from_seq(这两个方法在提供的代码中没有具体实现,而是放在之后的WordVocab中实现了),以及用于保存和加载词汇表的方法。
WordVocab:
WordVocab类继承自Vocab。它专门用于从文本数据构建词汇表。这个类的构造函数接受文本数据,计算单词频率,并使用这些信息来构建词汇表。
它重写了to_seq和from_seq方法,这些方法用于将句子转换为索引序列,以及将索引序列转换回文本。
这三个类构成了一个层次结构,其中每个子类在其父类的基础上增加了更具体的功能。TorchVocab提供了基本的词汇表功能,Vocab在此基础上增加了对特殊标记的处理,而WordVocab则是专门针对从文本数据构建和使用词汇表的场景。这种设计使得代码更加模块化和可重用,同时也提供了灵活性,以便于根据特定需求进行扩展或修改。
from torch import Dataset
import tqdm
import torch
import random
class BERTDataset(Dataset):
def __init__(self,corpus_path,vocab,seq_len,encoding="utf-8",corpus_lines=None,on_memory=True):
self.vocab=vocab
self.seq_len=seq_len
self.on_memory=on_memory
self.corpus_lines=corpus_lines
self.corpus_path=corpus_path
self.encoding=encoding
with open(corpus_path,"r",encoding=encoding) as f:
if self.corpus_lines is None and not on_memory:
for _ in tqdm.tqdm(f,desc="Loading Dataset",total=corpus_lines):
self.corpus_lines+=1
if on_memory:
self.lines=[line[:-1].split("\t") for line in tqdm.tqdm(f,desc="Loading Dataset",total=self.corpus_lines)]
self.corpus_lines=len(self.lines)
if not on_memory:
self.file=open(corpus_path,"r",encoding=encoding)
self.random_file=open(corpus_path,"r",encoding=encoding)
for _ in range(random.randint(self.corpus_lines if self.corpus_lines<1000 else 1000)):
self.random_file.__next__()
#返回语料的长度
def __len__(self):
return self.corpus_lines
def __getitem__(self,item):
t1,t2,is_next_label=self.random_sent(item)
t1_random,t1_label=self.random_word(t1)
t2_random,t2_label=self.random_word(t2)
t1=[self.vocab.sos_index]+t1_random+[self.vocab.eos_index]#sos是文本开始标签 eos是文本结束也是分割标签
t2=t2_random+[self.vocab.eos_index]
#构建掩码标签
t1_label=[self.vocab.pad_index]+t1_label+[self.vocab.pad_index]
t2_label=t2_label+[self.vocab.pad_index]
#生成段落标签
segment_label=([1 for _ in range(len(t1))]+[2 for _ in range (len(t2))])[:self.seq_len]
bert_input=(t1+t2)[:self.seq_len]
bert_label=(t1_label+t2_label)[:self.seq_len]
padding=[self.vocab.pad_index for _ in range(self.seq_len-len(bert_input))]
bert_input.extend(padding),bert_label.extend(padding),segment_label.extend(padding)
output={"bert_input":bert_input,"bert_label":bert_label,"segment_label":segment_label}
return {key:torch.tensor(value) for key ,value in output.items()}
def random_word(self,sentence):
tokens=sentence.split()#将每个单词分词
output_label=[]
for i,token in enumerate(tokens):
prob=random.random()
if prob<0.15:
prob/=0.15 #相对概率
#80% randomly change token to mask token
if prob<0.8:
tokens[i]=self.vocab.mask_index
# 10% randomly change token to mask token
elif prob<0.9:
tokens[i]=random.randrange(len(self.vocab))#返回一个随机的单词在词典中的索引
# 10% randomly change token to current token
else:
tokens[i]=self.vocab.stoi.get(token,self.vocab.unk_index)#返回本单词在词典中的索引
output_label.append(self.vocab.stoi.get())#单词被替换,那么output_label需要记录ground_truth,即记录被替换的单词在词典中的索引
else:
tokens[i]=self.vocab.stoi.get(token,self.vocab.unk_index)
output_label.append(0)
return tokens,output_label
"""
内存中读取 (self.on_memory 为 True):
当 self.on_memory 为 True 时,函数直接从 self.lines 列表中返回索引为 item 的行,其中 self.lines[item][0] 和 self.lines[item][1] 分别代表该行的两个文本片段。
文件中读取 (self.on_memory 为 False):
如果 self.on_memory 为 False,函数尝试从打开的文件对象 self.file 中读取下一行。
如果读取的行是 None(意味着到达了文件的末尾),它会关闭文件,重新打开文件,然后再次读取下一行。
读取的行假定为一个用制表符(\t)分隔的两部分文本,这两部分文本通过 split("\t") 方法分割,并赋值给 t1 和 t2。
"""
#"I like to eat pizza."\t"It is delicious."
# t1 t2
def get_corpus_line(self,item):
if self.on_memory:
return self.lines[item][0],self.lines[item][1]
else:
line=self.file.__next__()
if line is None:
self.file.close()
self.file=open(self.corpus_path,"r",encoding=self.encoding)
line=self.file.__next__()
t1,t2=line[:-1].split("\t")
return t1,t2
#它的作用是从文本数据中随机选择并返回一行。它处理两种情况:数据全部加载到内存中,或者从文件中逐行读取
def get_random_line(self):
if self.on_memory:
return self.line[random.randrange(len(self.lines))][1]
line=self.file.__next__()
if line is None:
self.file.closee()
self.file=open(self.corpus_path,"r",encoding=self.encoding)
for _ in range(random.randint(self.corpus_lines if self.corpus_lines <1000 else 1000)):
self.random_file.__next__()
line=self.random_file.__next__()
return line[:-1].split("\t")[1]
#创建一个训练样本,其中包括从数据集中选定的文本行(可能是一个句子或者段落),以及一个标签,标识这个文本行是否是与另一个文本行“下一个”相关联(即是否是连续的文本行)
def random_sent(self,index):
t1,t2=self.get_corpus_line(index)
if(random.random()>0.5):
return t1,t2,1
else:
return t1,self.get_random_line(),0
BERTDataset 类的核心功能是为训练 BERT 模型准备和处理数据。它处理文本数据,生成适合 BERT 训练的输入,包括掩码语言模型的输入和下一个句子预测任务的输入。通过随机替换、生成掩码和标签,这个类为 BERT 模型的预训练过程提供了必要的数据处理功能。
from .dataset import BERTDataset
from .vocab import WordVocab
这段代码就是象征性地引入两个上面提到的关键模块。
了解一个模型,不能只了解他的核心架构,如果你有代码实现的需求,就必须把这些数据处理的部分掌握好,这些是模型运行的基础,更是成功的根基。