BERT源码分析之数据预处理部分

我们以Mrpc任务来分析源码,Mrpc任务是判断两个句子是不是一个意思,

从run_classifier.py开始

首先定位到296行,找到类MrpcProcessor,所有预处理的类都继承自DataProcessor,这个类在177行,有一个read_tsv方法,

#我对这个方法适当的修改了,尽量避免大家没见过的方法,例如tf.gfile.Open 可以用open代替,csv.reader(delimiter="\t"),也完全可以用readlines()+split("\t)代替
def read_tsv(input_file,seperated="\t"):
	#seperated就是delimiter,quotechar是指对于双引号的区域也要引用,这个可以忽略
	with open(input_file) as f:
		lines=f.readlines()
	#lines是一个列表,每一行是一个字符串,也就是input_file的每一行文本
	output=[]
	for line in lines:
		output.append(line.strip().split(seperated))
		#seperated由具体文件内的文本格式决定
	return output

这个类还有四个方法是需要一一实现的,回到296行
我们可以看到四个方法一一实现了,前三个是为了得到examples.通过self.read_tsv方法我们知道返回的是一个长度和文件行数一样长的list,每一个元素也是一个list,是文件中的每一行文本字符串经过"\t"分割后的list。接下来将这个list送入create_examples方法生成examples,现在看create_examples,在318行:

#在进入create_examples之前还要知道tokenization中的一个方法convert_to_unicode,在tokenization文件中的78行
def convert_to_unicode(text):
	#注意传进来的text是一句话或者一个单词,也就是一个字符串
	#这个函数也可以适当的修改,因为现在肯定都是python3,而且python3默认的编码方式就是utf-8,所以这个函数甚至可以不用
	#但是由于后面经常引用这个函数,还是把它写上吧
	if isinstance(text,str):
		return text#text在python3中就是str形式
	elif isinstance(text,bytes):
		return text.decode("utf-8","ignore")#由bytes转到utf-8要调用decode,由utf-8转到bytes要调用encode
	else:
		raise ValueError("Unsupported string type: %s "%(type(text)))

#知道convert_to_unicode怎么回事了下面回到run_classifier.py的318行 

def create_examples(self,lines,set_type):
	examples=[]
	for (i,line) in enumerate(lines):
		if i==0:
			continue#注意,官方Mrpc任务中给出的数据集的第一行不是训练的文本
			#传进来的lines的第一行是这样的['\ufeffQuality', '#1 ID', '#2 ID', '#1 String', '#2 String']
			#所以i==0时continue
		guid="%s-%s" % (set_type,i)#set_type是"train","dev","test"中的某一个
		text_a=tokenization.convert_to_unicode(line[3])
		text_b=tokenization.convert_to_unicode(line[4])
		#把句子1和句子2拿出来
		if set_type=="test":
			label="0"#测试时不要标签
		else:
			label=tokenization.convert_to_unicode(line[0])	
		examples.append(InputExample(guid,text_a,text_b=text_b,label=label))
	return examples
#这个InputExample在127行,就是一个类,甚至没有方法。

现在定位到main方法,在783行,注意接下来我写的代码中适当的略去了源码中main方法的某些行代码,略去的代码不影响我们分析源码

processors={"mrpc": MrpcProcessor,}#其他的就不写了
task_name=FLAGS.task_name.lower()#这个就是我们在命令行中输入的任务名称,从这可以看出来命令行中输入的字符大小写均可。
processor=processors[task_name]()#相当于processor=MrpcProcessor()
label_list=processor.get_labels()#['0','1']
#注意我们现在所在的行数是817行,下面的tokenizer带着我们跑到了另一个文件tokenization.py

我们接下来就进入到tokenization.py中,我们最终是要知道FullTokenizer是什么,不过之前需要对里面的函数逐个分析

#首先说明一下源码中随处可见的unicodedata
unicodedata.categorty(char)#返回一个字符在unicode里的类别,源码中用到的类别有
'''
[Cc]Other,Control
[Cf]Other,Format
[Mn]Mark,Nonspacing
[Nd]Number,Digit
[Po]Punctuation,Other
[Zs]Separator,Space
'''
#例如unicodedata.category(char)如果是Zs的话,那就意味着这个字符是分隔符或者空格
#unicodedata.category(char)返回的类别如果是以P开头的话,那么说明这个字符是punctuation(标点符号的意思)
#unicodedata.category(char)如果是Cc,或者Cf,那么说明char是控制字符(我不了解什么是控制字符)
unicodedata.normalize(form,unistr)#将unicode编码形式的unistr转成普通格式的字符串
#normalize主要是解决那些特别奇怪的字符,像平时我们见到的无论是中文还是英文,没有说需要normalize的。
def is_whitespace(char):#362行
	#\t,\r,\n是控制字符,但是源码中把它们视为是空白字符
	if char==" " or char =="\t" or char=="\r" or char=="\n":
		return True
	if unicodedata.category(char)=="Zs":
		return True#"Zs" 对应于separator,space
	return False

def is_control(char):
	#注意\t \r \n源码中视为是空白字符
	if char=="\t" or char=="\r" or char=="\n":
		return False
	if unicodedata.category(char) in ("Cc","Cf"):
		return True
	return False

def is_punctuation(char):
	#这个就不写了,判断是不是标点符号的

上面三个函数就是用来判断一个字符是空格,分隔符,换行,标点符号,控制字符的哪一类。

#定位到121行load_vocab
def load_vocab(vocab_file):#这个vocab_file就是我们下载的预训练模型中的vocab.txt
	vocab=collections.OrderedDict()#和{}的区别就是这个有序,我也不知道换成普通的{}行不行
	index=0
	with open(vocab_file) as f:
		lines=f.readlines()
	for line in lines:
		token=convert_to_unicode(line)
		if not token:
			break
		vocab[token.strip()]=index
		index+=1
	return vocab#这个函数的作用就是建立vocab.txt中每一个词到对应的id的一个词典

#152行
def whitespace_tokenize(text):
	text=text.strip()
	if not text:
		return []
	tokens=text.split()
	return tokens
#whitespace_tokenize函数就是一句话text.strip().split(),也就是将一行字符串按照空格分割

FullTokenizer由两部分组成BasicTokenizer+WordPieceTokenizer

class BasicTokenizer(object):
	def __init__(self,do_lower_case=True):
		self.do_lower_case=do_lower_case
	def clean_text(self,text):
		#text是一句话,clean_text就是将这句话中的\t,\r,\n替换成空格,对于其它无效字符或者控制字符直接去掉
		output=[]
		for char in text:
			if ord(char)==0 or ord(char)==0xfffd or is_control(char):
				continue#ord()就是将字符转成对应的整数值,例如ord('a')=97
			if is_whitespace(char):
				output.append(' ')#\t \r \n用空格代替
			else:
				output.append(char)
		return "".join(output)#返回的是字符串
	def tokenize_chinese_chars(self,text):
		output=[]
		for char in text:
			cp=ord(char)
			if self.is_chinese_char(cp):
				output.append(" ")
				output.append(char)
				output.append(" ")
			else:
				output.append(char)
		return "".join(output)
	'''
	说明一下,假如输入的句子是"处理 中文 是 按照 字 来 处理 的"
	那么tokenize_chinese_chars的输出是' 处  理 \u3000 中  文 \u3000 是 \u3000 按  照 \u3000 字 \u3000 来 \u3000 处  理 \u3000 的 '
	\u3000就是中文的空格
	再经过whitespace_tokenize后就变成了
	['处', '理', '中', '文', '是', '按', '照', '字', '来', '处', '理', '的']
	这时看一下196行tokenize方法就明白流程了
	text=convert_to_unicode(text)#通常这个函数用不上
	text=clean_text(text)#将\t \r \n等换成空格,控制字符等去掉
	text=tokenize_chinese_chars(text)#针对中文的处理
	'''
	def run_strip_accents(self,text):
		#accents是重音符号的意思,这个貌似中文用不上,其它的语言有重音符号的现象
		text=unicodedata.normalize("NFD",text)#normalize的作用就是将text转换成普通字符
		output=[]
		for char in text:
			if unicodedata.category(char)=="Mn":
				continue#Mn 指的是Mark nonspace
			output.append(char)
		return "".join(output)
	
	def run_split_on_punc(self,text):
		#将text中的标点符号与单词分离,
		#输入是"Splited the sentence, with punctuations."
		#输出是['Splited the sentence', ',', ' with punctuations', '.']
		chars=list(text)#所有单个字符组成的列表
		start_new_word=True
		outputs=[]
		for i in range(len(chars)):
			char=chars[i]
			if is_punctuation(char):
				outputs.append([char])
				start_new_word=True
			else:
				if start_new_word==True:
					outputs.append([])#如果开始一个新的单词,那么在output中加入一个[]
				outputs[-1].append(char)#这个单词的每一个字符就会相继的加入到这个[]中
				start_new_word=False
		return ["".join(x) for x in outputs]
		#注意观察211-215行就会发现,其实输入给run_split_on_punc的是单词,而不是句子,假如输入的单词没有标点符号那么这个函数就没什么作用
		#输入的单词是单词如果是"punctuations."
		#输出就是["punctuations","."]
	#现在我们来看tokenize函数196行,一定要知道whitespace_tokenize(text)函数就是一句代码text.strip().split()
	def tokenize(self,text):
		#假设输入的是"This is basic, tokenizer.\n"
		text=convert_to_unicode(text)#This is basic, tokenizer.\n
		text=self.clean_text(text)#This is basic, tokenizer.
		text=self.tokenize_chinese_chars(text)#This is basic, tokenizer.
		orig_tokens=whitespace_tokenize(text)#["This","is","basic,","tokenizer."]
		split_tokens=[]
		for token in orig_tokens:
			if self.do_lower_case:
				token=token.lower()
				token=self.run_strip_accents(token)
			split_tokens.extend(self.run_split_on_punc(token))
	#split_tokens==["this","is","basic",",","tokenizer","."]
	output_tokens=whitespace_tokenize(" ".join(split_tokens))
	return output_tokens

所以说BasicTokenize所做的就是将一个句子,去掉特殊字符,控制字符,\t,\r,\n等字符,以及将带有标点符号的单词与标点符号分离,最后返回一个列表,每一个元素值就是一个token.
下面来看WordpieceTokenizer,定位300行

class WordpieceTokenize(object):
	def __init__(self,vocab,unk_token="[UNK]",max_input_chars_per_word=200):
		self.vocab=vocab
		#vocab就是load_vocab返回的vocab
		self.max_input_chars_per_word=max_input_chars_per_word
		self.unk_token=unk_token

	def tokenize(self,text):
		#输入"this is in wordpiece tokenizer"
		#输出['this', 'is', 'in', 'word', '##piece', 'token', '##izer']
		text=convert_to_unicode(text)#this is in wordpiece tokenizer
		output_tokens=[]
		token_list=whitespace_tokenize(text)#[this,is,in,wordpiece,tokenizer]
		for token in token_list:
			char=list(token)#['w','o','r','d','p','i','e','c','e']
			if len(chars)>self.max_input_chars_per_word:
				output_tokens.append(self.unk_token)
				continue#这个基本用不上,一个单词,怎么可能有200个字符那么长,如果真这么长的话,就用UNK代替
			is_bad=False
			start=0
			sub_tokens=[]#sub_tokens记录的是一个单词token会有多少个子单词
			
			#正向最大匹配,假设现在token是wordpiece,len(chars)==9
			while(start<len(chars)):
				end=len(chars)
				cur_substr=None#先假设这个单词没有子单词,如果在下面的while循环中找到子单词,就赋值给cur_substr
				while(start<end):#从最后一个字符逐步往前找子单词
					substr="".join(chars[start:end])#第一次时cur_substr=wordpiece
					if start>0:
						substr="##"+substr
					if substr in self.vocab:
						cur_substr=substr
						break
					end-=1#当end=4时,chars[0:4]==substr=="word",此时找到了子串,cur_substr="word" break
				if cur_substr is None:#也就是说无论是整个单词还是子串都没有在vocab中出现过
					is_bad=True
					break#跳出循环
				#此时cur_substr=="word"
				sub_tokens.append(cur_substr)#将找到的子字符串加入到sub_tokens,然后start=end,接着找,会找到piece,而piece是在token的中间,所以substr="##piece"
				start=end#start=4,此时start:end就是单词piece
			if is_bad:
				#没找到怎么办呢,用UNK代替这个token
				output_tokens.append(self.unk_token)
			else:
				output_tokens.append(sub_tokens)
		return output_tokens#['this', 'is', 'in', 'word', '##piece', 'token', '##izer']
		#output_tokens是针对整个句子的,sub_tokens是针对一个单词的,cur_substr是针对一个单词的一个子串的。
		#所以整个tokenize翻译过来就是给一个字符串句子,先用whitespace_tokenize(text)将每一个单词切分出来(注意wordpiece_tokenize的text是经过basic_tokenize后传进去的,所以不用担心特殊字符,标点符号等问题),
		#切分后是一个列表,每一个元素是一个单词,对于每一个单词,取出来单词内的所有字符(chars=list(tokens)),然后从起始位置找这个单词的子单词有没有在vocab中出现过,如果无论是单词还是子串都没有在vocab中那么就用unk_token代替,所以正如this is in wordpiece tokenizer的输出所示:
		#this is in三个单词没有被切分是因为这三个单词在vocab中均出现过,而wordpiece,tokenizer没有在vocab出现过,而word,token出现过,那么就将wordpiece切分成word+##piece,tokenizer切分成token+##izer,这里##的目的我猜是为了标明这是个切分的单词,而且是单词的尾部	

介绍完了BasicTokenizer和WordpieceTokenizer,那么就可以引入FullTokenizer了,定位161行

class FullTokenizer(object):
	def __init__(self,vocab_file,do_lower_case=True):
		self.vocab=load_vocab(vocab_file)
		self.inv_vocab={k:v for v,k in self.vocab.items()}
		#其实就是word2id和id2word
		self.basic_tokenizer=BasicTokenizer(do_lower_case=do_lower_case)
		self.wordpiece_tokenizer=WordpieceTokenizer(vocab=self.vocab)
	def tokenize(self,text):
		#传进来的text是一句话
		split_token=[]
		for token in self.basic_tokenizer(text):
			#basic_tokenizer(text)返回的是一个列表,列表中每一个元素是去掉了特殊符号,标点符号的单词
			for sub_token in self.word_tokenizer(token):
				#传给wordpiece_tokenizer的是一个单词,wordpiece_tokenizer返回的要么是原来的这个单词(说明这个单词在vocab中),要么是子串(如word,##piece,说明这个单词没有在vocab中,但是子部分在vocab中)	要么就是一个unk_token,说明这是一个bad token.
				split_tokens.append(sub_token)
		return split_token	

现在回到run_classifier.py的783行main函数

processors={"mrpc": MrpcProcessor,}#其他的就不写了
task_name=FLAGS.task_name.lower()#这个就是我们在命令行中输入的任务名称,从这可以看出来命令行中输入的字符大小写均可。
processor=processors[task_name]()#相当于processor=MrpcProcessor()
label_list=processor.get_labels()#['0','1']
#注意我们现在所在的行数是817行,下面的tokenizer带着我们跑到了另一个文件tokenization.py
#现在我们回到了tokenization.py
tokenizer=tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,do_lower_case=FLAGS.do_lower_case)#FLAGS.vocab_file就是下载的预训练文件中的vocab.txt
if FLAGS.do_train:
	train_examples=processor.get_train_examples(FLAGS.data_dir)#train_examples是一个列表,每一个元素值是InputExample的一个对象,这个对象有四个属性guid,text_a,text_b,label。所以你可以把train_examples看成是一个大的列表,长度和data_dir给出的文件的行数一样长,每一个元素值记录着文件的每一行.
train_file=os.path.join(FLAGS.output_dir,"train.tf_record")#注意train_file是你在命令终端中输入的output_dir的位置,不要把它误以为是train.txt所在的文件位置
#代码会将train.txt经过一系列的操作后以tfrecord格式存储到train_file中,模型训练时是从train_file中读取数据,所以命名为train_file

#下面定位到869行,file_based_convert_examples_to_features(),传进去的变量有train_examples.label_list,max_seq_length,tokenizer,train_file,
#这几个变量大家应该知道都是什么了max_seq_length默认是128
#接下来进入file_based_convert_examples_to_features()
#定位到479行

再来看看另外的一些函数

class PaddingInputExample:
	#就没有了,这个类的目的是用一个什么都没有的对象代替None
class InputFeatures(object):
	def __init__(self,input_ids,input_mask,segment_ids,label_id,is_real_example=True):
		#input_ids是一个句子中每一个单词对应的在vocab中的索引
		#input_mask是指传进来的input_ids中那些单词是pad的,需要mask的,
		#segment_ids是指明那些单词是第一个句子,那些单词是第二个句子
		#label_id是一个整数值,表明当前这个example的标签
		assert len(input_ids)==len(input_mask)==len(segment_ids)

def truncate_seq_pair(tokens_a,tokens_b,max_seq_length):
	#如果两个句子的长度加起来>max_seq_length,那么就将长度比较长的句子截断。(不截断句子短的是因为本来就短,再截断后整个句子就没什么信息了)
	while True:
		sentence_length=len(tokens_a)+len(tokens_b)
		if sentence_length<=max_seq_length:
			break
		if len(tokens_a)>len(tokens_b):
			tokens_a.pop()
		else:
			tokens_b.pop()#列表还有pop的操作

def convert_single_example(example_index,example,label_list,max_seq_length,tokenizer):
	#传进来的example是examples的一行,examples是InputExample的对象
	if isinstance(example,PaddingInputExample):
		#也就是说如果传进来的example是None的话,那么就会执行下面的语句
		return InputFeatures(input_ids=[0]*max_seq_length,
			input_mask=[0]*max_seq_length,segment_ids=[0]*max_seq_length,
			label_id=0,is_real_example=False)
		#这么做的目的看着挺困惑,源码中给出的解释是为了使得输入的examples的长度是batch_size的倍数,因为TPU需要固定的batch_size,
		#所以说输入的examples的最后的比batch_size少的那几个example,会在它们后面加上一些没有实际值的InputFeatures
	label_map={}
	for (i,label) in enumerate(label_list):
		label_map[label]=i
	tokens_a=tokenizer.tokenize(example.text_a)
	tokens_b=None
	if example.text_b:
		tokens_b=tokenizer.tokenize(example.text_b)
	if tokens_b:
		truncate_seq_pair(tokens_a,tokens_b,max_seq_length-3)#-3是因为如果有两个句子,那么就会有三个特殊字符[CLS]+tokens_a+[SEP]+tokens_b+[SEP]
	else:
		#只有一个句子
		if len(tokens_a)+2>max_seq_length:
			tokens_a=tokens_a[:max_seq_length-2]
	tokens=[]#tokens记录的是每一个单词
	segment_ids=[]#记录的是单词是属于哪一个句子的
	tokens.append("[CLS]")
	segment_ids.append(0)
	for token in tokens_a:
		tokens.append(token)
		segment_ids.append(0)
	tokens.append("[SEP]")
	segment_ids.append(0)
	
	if tokens_b:
		for token in tokens_b:
			tokens.append(token)
			segment.append(1)
		tokens.append("[SEP]")
		segment_ids.append(1)
	input_ids=tokenizer.convert_tokens_to_ids(tokens)#这个函数的名字就告诉我们它的作用了,将每一个token转换成对应的id
	input_mask=[1]*len(input_ids)#len(input_ids)是句子的真实长度,所以在mask的时候是不能mask这个长度下单词的,所以input_mask在0-len(input_ids)的位置都是1
	masked_length=max_seq_length-len(input_mask)
	input_ids+=[0]*masked_length
	input_mask+=[0]*masked_length
	segment_ids+=[0]*masked_length#这里注意的是segment_ids前面为0代表这个单词是第一个句子的,中间为1代表是第二个句子的,1后面还会有0,代表是pad的
	label_id=label_map[example.label]#相当于label2id[label],也就是找到当前这个example的标签类别
	return InputFeatures(input_ids,input_mask,segment_ids,
		label_id,is_real_example=True)
'''
也就是说传进来的是一个example,也就是InputExamples的一个对象,没有方法,有四个属性guid,text_a,text_b,label,
而经过convert_single_example后返回的是InputFeatures的一个对象,没有方法,有四个属性,input_ids,input_mask,segment_ids,label_id

所以可以拿一个句子对来举例子:
["example is a object of input examples","feature is a object of input features","1"]
那么经过convert_single_example后就是
[input_ids=[1,2,3,4,5,6,7,8,2,3,4,5,6,7,9,0,0,0,0,0,0,0,0,000......]
segment_ids=[0,0,0,0,0,0,0,1,1,1,1,1,1,0,000000,00..]
input_mask=[1,1,1,1,1,1,1,1,0000000000000.......0]
label_id=1]
'''

TFRecord数据文件是将数据和对应的标签统一存储的二进制格式文件,生成tfrecord文件的格式是先读取原生数据,根据原生数据生成tfrecord文件,再写回磁盘。
然后再利用API从磁盘读取tfrecord文件
一般的生成tfrecord的步骤是

tf.train.Example(features=tf.train.Features(feature={key:value}))
#其中key是你起的特征的名字,value就是特征,大致分为三种类型的特征:
'''
Int64List,用来存储int型数据
BytesList,用来存储字符串
FloatList,用来存储浮点型数据
'''
#下面的就是源码中给出的形式
tf_example=tf.train.Example(
features=tf.train.Features(feature=
{"input_ids":tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.input_ids))),
"input_mask":tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.input_mask))).
"segment_ids":tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.segment_ids))),
}))
#然后tf_example序列化,再将序列化后的文件写成tfrecord文件

下面我们来看file_based_convert_examples_to_features,这个函数的用处就是将examples写成tfrecord,定位到479行

def file_based_convert_examples_to_features(examples,label_list,max_seq_length,tokenizer,output_file):
	#examples就是InputExamples的一个对象,没有方法,有四个属性,guid,text_a,text_b,label
	#label_list在Mrpc任务中是['0','1']
	#max_seq_length 默认是128
	#output_file就是将examples转换成tfrecord后的文件输出位置
	writer=tf.python_io.TFRecordWriter(output_file)
	for (example_index,example) in enumerate(examples):
		feature=convert_single_example(example_index,example,label_list,max_seq_length,tokenizer)
		#传进去的example_index的作用是为了打印所有example的前五个,这就是当运行源码时一开始会看到一大堆的输出信息,包括input_ids,input_mask,segment_ids,你也可以在convert_single_example中去掉打印的那几行
		#convert_single_example返回的feature是就是INputFeatures的一个对象,没有方法,有几个属性包括input_ids,input_mask,segment_ids,label_id,is_real_example等
		'''接下来注意,我上面写的那个代码块和源码中492-502行意思是一样的,
			目的就是构造一个dict,dict的key是特征名字,对应的value就是由tf.train.Feature创造的特征
			然后将构造的dict作为features生成tf.train.Example的一个实例
		'''
		tf_example=tf.train.Example(features=tf.train.Features(feature=features))
		writer.write(tf_example.SerializeToString())
		#每生成一个特征feature,就把它转成tfrecord格式,序列化后写入文件
	writer.close()
	
#上面就已经生成了tfrecord文件,接下来就是读取
def file_based_input_fn_builder(input_file,max_seq_length,is_training,drop_remainder):
	#input_file就是上一个函数的output_file
	'''读取的API主要是tf.data.TFRecordDataset(file_path),返回的是一个dataset,
	这个dataset中的每一个元素就是序列化的一个tf_example,我们要把它解析回原来的类型
	(详细说明下就是由convert_single_example返回的feature是一个对象,有几个属性(此时打印feature可以看到几个列表),利用这些属性值将feature改造成一个tf.train.Feature(此时打印feature就是json格式的类型,key就是名,value就是列表),然后序列化转成tfrecord(此时打印feature就是一个二进制字符串),
	这就是生成tfrecord的过程,读取tfrecord文件再解析回原来列表型的数据类型就是相应的逆过程)
	'''
	def input_fn(params):
		dataset=tf.data.TFRecordDataset(input_file)
		dataset=dataset.apply(tf.contrib.data.map_and_batch(
			lambda record:decode_record(record,name_to_features),
			batch_size=params["batch_size"],
			drop_remainder=drop_remainder))
		return dataset
	#dataset.apply()的作用就是将dataset放到tf.contrib.data.map_and_batch()中,map_and_batch()会将dataset中的数据以batch_size的个数拿出来放到decode_record中,
	#如果最后的部分不足batch_size,drop_remainder=True的意思就是去掉这部分。
	#我们刚才说了,你从dataset中读取出来的是二进制字符串,需要把它解析回原来的列表格式才能送进网络,decode_record的作用就是解析
	def decode_record(record,name_to_features):
		example=tf.parse_single_example(record,name_to_features)
		#parse_single_example就是解析record,name_to_features的作用就是告诉函数record中的每个值原来是什么类型的
		'''
		name_to_features={
			"input_ids":tf.FixedLenFeature([seq_length],tf.int64)
		}
		'''
		#意思就是说record中name为input_ids的数据原来的格式是一个长度为seq_length的列表,类型是tf.int64
		return example#只不过源码中把所有tf.int64类型的数据都转成tf.int32的数据
	return input_fn

好的现在回到868行

train_file=os.path.join(FLAGS.output_dir,"train.tf_record")
#output_dir就是你在终端中输入的参数output_dir,并且程序执行过程中你也看到了Writting example %d of %d,对应487行
#这就是说模型在向output_dir中写入tf_record文件.
file_based_convert_examples_to_features(train_examples,label_list,max_seq_length,tokenizer,train_file)#生成tfrecord数据写到train_file中
train_input_fn=file_based_input_fn_builder(train_file,max_seq_length)#读取tfrecord文件,解析回原来的数据格式,返回的是函数
estimator.train(train_input_fn,num_train_steps)
#这就是整个数据处理的流程

回头总结一下数据处理的流程

  • 首先自己写一个MyProcessor,然后在790行那里加上名字。你的MyProcessor主要的方法有get_train_examples,get_dev_examples,get_test_examples.这三个方法返回的是InputExamples的一个对象,有四个属性,guid,这个不重要,text_a,text_b,label,后三个就是典型的Mrpc任务需要的数据,即句子1,句子2,标签。
  • 然后把get_train_examples返回的实例传入到file_based_convert_examples_to_features中生成tfrecord文件。
  • file_based_input_fn_builder会读取tfrecord文件,返回一个函数给estimator.
  • 剩下的就是细节,尤其是tokenization.py文件,里面的函数几乎要全部了解。(注意我只是介绍了数据处理的流程)

你可能感兴趣的:(BERT源码分析之数据预处理部分)