Bert源码解析--训练集生成

这一部分的源码主要实现在create_pretraining_data.py和tokenization.py两个脚本里。

先介绍主要部分:create_pretraining_data.py

这里小标1,2用的太多了,为了方便区分,我用了不同颜色(绿)的小标表示,同一个颜色是一个部分的;脚本中用到的函数,我用紫色的进行了标识。

源码地址:https://github.com/google-research/bert/blob/master/create_pretraining_data.py

一、create_pretraining_data.py:原始文件转换为训练数据格式

输入:词典,原始文本(空行分割不同的文章,一行一句话

输出:训练数据

作用:生成训练数据,句子对组合、单词mask等

步骤

         1. 加载词典,加载原始数据(即:输入部分)

         2. 主要部分 -- 针对文章进行处理,得到训练数据

         (1)按行读取原始文本。         

         (2)在经过 tokenization.py处理后,原始文本会转换为[[[first doc first sentence],[first doc second sentence],[first doc third sentence]],[[second doc first sentence],[]],....] 这样的结构。

           注意:tokenization.py (对原始数据进行格式上的处理,eg:Unicode转换、空格切分、复合词切分等,部分二将介绍)

         (3)对文章顺序进行打乱。

          注意:以上步骤在create_training_instances中,并调用create_instances_from_document返回数据类型instances

         (4) create_instances_from_document -- 针对每篇文章进行处理,生成训练数据(此过程重复dupe_factor次,默认为10)。

create_instances_from_document 过程详解

(1)生成句子对:

这里需要先介绍两个参数:target_seq_length 、 max_seq_length

         max_seq_length 初始化:

flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")

         target_seq_length (单条样本最长值)初始化:

max_num_tokens = max_seq_length - 3
target_seq_length = max_num_tokens

接下来是生成句子对的具体过程

        a.  从第一个句子循环到最后一个句子,收集segment(句子)到current_chunk(一个列表)中,当收集到的总句子长度(这个长度相当于句子中单词的个数)>= target_seq_length时,构造A+B。

        那么是怎么构造A+B的呢?

        随机截取current_chunk中的 某一个位置a_end(a_end是按照整个句子进行选择的),则 [0,a_end] 作为句子A;句子B的选取是有 next 或者 not next 两种概率得到的,每种概率为 50% 。当选择为next时, B =[a_end,:],否则,B是随机选取的其他文章内容。

        b. 添加分隔符,同时生成segment_ids(A句为0,B句为1)

(2)随机mask:

对句子中的单词做随机mask(调用create_masked_lm_predictions),随机取num_to_predict个单词做mask,其中 0.8 的概率标记为[MASK],0.1的概率标记为原始单词,0.1的概率标记为随机单词。

(tokens, masked_lm_positions,
         masked_lm_labels) = create_masked_lm_predictions( 
         tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)

输入和返回结果举例:

input tokens  ="The man went to the store . He bought a gallon of milk "
ouput tokens ="The man went to the [mask] . He [mask] a gallon of milk"
output masked_lm_positions = [5, 8, 10, 11]
output masked_lm_labels = [store, bought, gallon, ice]

(3)封装

instance = TrainingInstance(
            tokens=tokens,
            segment_ids=segment_ids,
            is_random_next=is_random_next,
            masked_lm_positions=masked_lm_positions,
            masked_lm_labels=masked_lm_labels)
instances.append(instance)

segment_ids(A句为0,B句为1)

 3. 训练数据序列化,存入文件---write_instance_to_example_files

def write_instance_to_example_files(instances, tokenizer, max_seq_length,
                                    max_predictions_per_seq, output_files):

输入:

        instances : 见上述

       tokenizer:tokenization.py 得到

       max_seq_length:见上述

       max_predictions_per_seq:每个序列最大mask长度

       output_files:将instances写入output_files

这里要说的有四点:

(1)ID化

对单词进行ID化

input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)

需要进行ID化的还有之后的masked_lm_ids

masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)

(2)mask权重设置:方便后续loss计算

masked_lm_weights = [1.0] * len(masked_lm_ids)  

(3)按照max_predictions_per_seq和max_seq_length长度,分别对相应样本补齐,padding为0

while len(input_ids) < max_seq_length:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(0)

注意:input_mask是样本中有效词句的标识,之后需要用作attention视野的约束。

while len(masked_lm_positions) < max_predictions_per_seq:
      masked_lm_positions.append(0)
      masked_lm_ids.append(0)
      masked_lm_weights.append(0.0)

(4)把instance的is_random_next转化成变量next_sentence_label保存。

next_sentence_label = 1 if instance.is_random_next else 0

实例:

INFO:tensorflow:*** Example ***
INFO:tensorflow:tokens: [CLS] this was nearly opposite . [SEP] at last dunes reached the quay at [MASK] opposite end of [MASK] street [MASK] and there burst on [MASK] ##am ##mon [MASK] s [MASK] eyes a vast semi [MASK] ##rcle of blue sea , ring ##ed with palaces and towers . [MASK] stopped in ##vo ##lun ##tar [MASK] ; and his little guide [MASK] also , and looked ask ##ance at the young monk , [MASK] watch the effect which that [MASK] panorama should produce on him . [SEP]
INFO:tensorflow:input_ids: 101 2023 2001 3053 4500 1012 102 2012 2197 17746 2584 1996 21048 2012 103 4500 2203 1997 103 2395 103 1998 2045 6532 2006 103 3286 8202 103 1055 103 2159 1037 6565 4100 103 21769 1997 2630 2712 1010 3614 2098 2007 22763 1998 7626 1012 103 3030 1999 6767 26896 7559 103 1025 1998 2010 2210 5009 103 2036 1010 1998 2246 3198 6651 2012 1996 2402 8284 1010 103 3422 1996 3466 2029 2008 103 23652 2323 3965 2006 2032 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:tensorflow:masked_lm_positions: 9 14 18 20 25 28 30 35 48 54 60 72 78 0 0 0 0 0 0 0
INFO:tensorflow:masked_lm_ids: 2027 1996 1996 1025 6316 1005 22741 6895 2002 6588 3030 2000 2882 0 0 0 0 0 0 0
INFO:tensorflow:masked_lm_weights: 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
INFO:tensorflow:next_sentence_labels: 1

 

你可能感兴趣的:(bert)