
pre-train是迁移学习的基础,虽然Google已经发布了各种预训练好的模型,而且因为资源消耗巨大,自己再预训练也不现实(在Google Cloud TPU v2 上训练BERT-Base要花费近500刀,耗时达到两周。在GPU上可想而知只会更贵),但是学习bert的预训练方法可以为我们弄懂整个bert的运行流程提供莫大的帮助。预训练涉及到的模块有点多,所以这也将会是一篇长文,在能简略的地方我尽量简略,还是那句话,我的文章只能是起到一个导读的作用,如果想摸清里面的各种细节还是要自己把源码过一遍的。





class BasicTokenizer(object):
    """Whitespace/punctuation tokenizer that runs before WordPiece.

    The pipeline is: unicode conversion, removal of invalid characters and
    whitespace normalisation, isolation of CJK characters, whitespace split,
    optional lower-casing plus accent stripping, and finally punctuation
    splitting.
    """

    def __init__(self, do_lower_case=True):
        """Args:
          do_lower_case: when true, every token is lower-cased and has its
            accents stripped before punctuation splitting.
        """
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Return the list of basic tokens for ``text``."""
        text = self._clean_text(convert_to_unicode(text))
        # Pad every CJK character with spaces so each becomes its own token.
        text = self._tokenize_chinese_chars(text)

        pieces = []
        for word in whitespace_tokenize(text):
            if self.do_lower_case:
                word = self._run_strip_accents(word.lower())
            pieces.extend(self._run_split_on_punc(word))
        # Join and re-split on whitespace to produce the final token list.
        return whitespace_tokenize(" ".join(pieces))

    def _run_strip_accents(self, text):
        """Remove combining marks (category "Mn") after NFD decomposition."""
        decomposed = unicodedata.normalize("NFD", text)
        return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")

    def _run_split_on_punc(self, text):
        """Split ``text`` so that every punctuation character is its own piece."""
        pieces = []
        fresh = True  # the next non-punctuation character opens a new piece
        for ch in text:
            if _is_punctuation(ch):
                pieces.append([ch])
                fresh = True
                continue
            if fresh:
                pieces.append([])
                fresh = False
            pieces[-1].append(ch)
        return ["".join(piece) for piece in pieces]

    def _tokenize_chinese_chars(self, text):
        """Surround every CJK character in ``text`` with single spaces."""
        return "".join(
            " " + ch + " " if self._is_chinese_char(ord(ch)) else ch
            for ch in text)

    def _is_chinese_char(self, cp):
        """Return True if code point ``cp`` lies in one of the CJK ideograph blocks.

        Only the CJK Unified Ideographs (base block and extensions A-E) and the
        compatibility ideograph blocks are included. Hangul, Hiragana and
        Katakana are not in these ranges.
        """
        cjk_blocks = (
            (0x4E00, 0x9FFF),    # CJK Unified Ideographs
            (0x3400, 0x4DBF),    # Extension A
            (0x20000, 0x2A6DF),  # Extension B
            (0x2A700, 0x2B73F),  # Extension C
            (0x2B740, 0x2B81F),  # Extension D
            (0x2B820, 0x2CEAF),  # Extension E
            (0xF900, 0xFAFF),    # CJK Compatibility Ideographs
            (0x2F800, 0x2FA1F),  # CJK Compatibility Ideographs Supplement
        )
        return any(low <= cp <= high for low, high in cjk_blocks)

    def _clean_text(self, text):
        """Drop NUL, U+FFFD and control characters; map whitespace to a space."""
        kept = []
        for ch in text:
            if ord(ch) in (0, 0xFFFD) or _is_control(ch):
                continue
            kept.append(" " if _is_whitespace(ch) else ch)
        return "".join(kept)



class WordpieceTokenizer(object):
    """Splits whitespace-separated words into WordPiece sub-tokens.

    Each word is segmented greedily, longest match first: starting at the
    current position, the longest substring present in ``vocab`` is taken,
    and non-initial pieces are looked up with a ``##`` prefix. A word that is
    longer than ``max_input_chars_per_word``, or that contains a span no
    vocabulary entry matches, is replaced as a whole by ``unk_token``.
    """

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Return the WordPiece tokens for every whitespace-separated word."""
        text = convert_to_unicode(text)
        output_tokens = []
        for word in whitespace_tokenize(text):
            pieces = self._split_word(word)
            if pieces is None:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(pieces)
        return output_tokens

    def _split_word(self, word):
        """Greedy longest-match-first split of ``word``; None if it is unknown."""
        if len(word) > self.max_input_chars_per_word:
            return None
        pieces = []
        start = 0
        length = len(word)
        while start < length:
            match = None
            # Try the longest remaining substring first, shrinking from the right.
            for end in range(length, start, -1):
                candidate = word[start:end]
                if start > 0:
                    candidate = "##" + candidate
                if candidate in self.vocab:
                    match = candidate
                    break
            if match is None:
                return None
            pieces.append(match)
            start = end
        return pieces

WordpieceTokenizer的目的是将合成词分解成类似词根一样的词片,例如将"unwanted"分解成["un", "##want", "##ed"]。具体做法是贪心的最长匹配:从词首开始,每次在词表中找能匹配上的最长子串,非词首的词片加上"##"前缀;只要有一段找不到匹配,整个词就记为[UNK]。这么做是为了避免因为词过于生僻、没有被收录进词典,最后只能以[UNK]代替的局面——英语当中这样的合成词非常多,词典不可能全部收录。


class FullTokenizer(object):
    """End-to-end tokenizer: basic tokenization followed by WordPiece.

    Also converts between token strings and vocabulary ids in both directions.
    """

    def __init__(self, vocab_file, do_lower_case=True):
        self.vocab = load_vocab(vocab_file)
        # Reverse mapping (id -> token) used by convert_ids_to_tokens.
        self.inv_vocab = {idx: tok for tok, idx in self.vocab.items()}
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        """Basic-tokenize ``text``, then split each basic token into WordPieces."""
        return [
            piece
            for word in self.basic_tokenizer.tokenize(text)
            for piece in self.wordpiece_tokenizer.tokenize(word)
        ]

    def convert_tokens_to_ids(self, tokens):
        """Map token strings to vocabulary ids."""
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        """Map vocabulary ids back to token strings."""
        return convert_by_vocab(self.inv_vocab, ids)




# Command-line flags for the pre-training data generation script.
flags.DEFINE_string("input_file", None,
                    "Input raw text file (or comma-separated list of files).")

# Restored: this definition had lost its `flags.DEFINE_string(` call, leaving
# a dangling tuple that is a syntax error.
flags.DEFINE_string(
    "output_file", None,
    "Output TF example file (or comma-separated list of files).")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")

flags.DEFINE_integer("max_predictions_per_seq", 20,
                     "Maximum number of masked LM predictions per sequence.")

flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")

flags.DEFINE_integer(
    "dupe_factor", 10,
    "Number of times to duplicate the input data (with different masks).")

flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")

flags.DEFINE_float(
    "short_seq_prob", 0.1,
    "Probability of creating sequences which are shorter than the "
    "maximum length.")






def main(_):
    """Tokenize the input corpus and write pre-training examples.

    Steps: (1) build the tokenizer, (2) turn the input files into
    `TrainingInstance`s, (3) serialize the instances to the output files.
    The single positional argument is unused.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    # --input_file may be a comma-separated list of glob patterns.
    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    tf.logging.info("*** Reading from input files ***")
    for input_file in input_files:
        tf.logging.info(" %s", input_file)

    # One seeded RNG is created here and passed down to instance creation.
    rng = random.Random(FLAGS.random_seed)
    instances = create_training_instances(
        input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
        FLAGS.short_seq_prob, FLAGS.masked_lm_prob,
        FLAGS.max_predictions_per_seq, rng)

    output_files = FLAGS.output_file.split(",")
    tf.logging.info("*** Writing to output files ***")
    for output_file in output_files:
        tf.logging.info(" %s", output_file)

    write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
                                    FLAGS.max_predictions_per_seq, output_files)

从入口开始看,步骤很简单:1)构造tokenizer ;2)构造instances ;3)保存instances


def create_training_instances(input_files, tokenizer, max_seq_length,
                              dupe_factor, short_seq_prob, masked_lm_prob,
                              max_predictions_per_seq, rng):
    """Build the shuffled list of `TrainingInstance`s for a raw-text corpus.

    Every non-empty line of the input files is tokenized into one segment,
    and an empty line starts a new document. Documents that end up with no
    segments are dropped. The document list is shuffled, then every document
    is turned into instances ``dupe_factor`` times (each pass draws new random
    choices from ``rng``), and the combined list is shuffled again.
    """
    documents = [[]]
    for path in input_files:
        with tf.gfile.GFile(path, "r") as reader:
            while True:
                line = tokenization.convert_to_unicode(reader.readline())
                if not line:
                    break  # an empty read marks the end of the file
                line = line.strip()
                if not line:
                    # Blank line: document boundary.
                    documents.append([])
                segment = tokenizer.tokenize(line)
                if segment:
                    documents[-1].append(segment)

    documents = list(filter(None, documents))
    rng.shuffle(documents)

    vocab_words = list(tokenizer.vocab.keys())
    instances = []
    for _ in range(dupe_factor):
        for doc_index in range(len(documents)):
            instances += create_instances_from_document(
                documents, doc_index, max_seq_length, short_seq_prob,
                masked_lm_prob, max_predictions_per_seq, vocab_words, rng)

    rng.shuffle(instances)
    return instances


def create_instances_from_document(
        all_documents, document_index, max_seq_length, short_seq_prob,
        masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
    """Creates `TrainingInstance`s for a single document.

    Walks the segments of ``all_documents[document_index]`` and packs
    consecutive segments into a chunk until the chunk reaches
    ``target_seq_length`` tokens or the document ends. Each chunk is split
    into an "A" part and a "B" part for next-sentence prediction. When the
    chunk has a single segment, or with probability 0.5 otherwise, B is taken
    from a randomly drawn document instead. The sequence
    ``[CLS] A [SEP] B [SEP]`` is then passed to
    ``create_masked_lm_predictions`` to choose the masked-LM targets.

    Args:
      all_documents: list of documents; each document is a list of segments
        (token lists), as built by ``create_training_instances``.
      document_index: index of the document to turn into instances.
      max_seq_length: maximum total length, including [CLS] and both [SEP].
      short_seq_prob: probability of using a random, shorter target length.
      masked_lm_prob, max_predictions_per_seq, vocab_words: forwarded to
        ``create_masked_lm_predictions``.
      rng: object providing ``random()`` and ``randint()`` (``main`` passes a
        ``random.Random``). The output for a given seed depends on the exact
        order of the calls made on it.

    Returns:
      A list of ``TrainingInstance`` objects.
    """
    document = all_documents[document_index]

    # Account for [CLS], [SEP], [SEP]
    max_num_tokens = max_seq_length - 3

    # Normally fill the whole budget; with probability short_seq_prob pick a
    # random shorter target length instead (randint bounds are inclusive).
    target_seq_length = max_num_tokens
    if rng.random() < short_seq_prob:
        target_seq_length = rng.randint(2, max_num_tokens)

    instances = []
    current_chunk = []  # segments gathered for the next instance
    current_length = 0  # total number of tokens in current_chunk
    i = 0
    while i < len(document):
        segment = document[i]
        current_chunk.append(segment)
        current_length += len(segment)
        # Emit an instance once the chunk is long enough or the document ends.
        if i == len(document) - 1 or current_length >= target_seq_length:
            if current_chunk:
                # `a_end` is how many segments from `current_chunk` go into the `A`
                # (first) sentence.
                a_end = 1
                if len(current_chunk) >= 2:
                    a_end = rng.randint(1, len(current_chunk) - 1)

                tokens_a = []
                for j in range(a_end):
                    tokens_a.extend(current_chunk[j])

                tokens_b = []
                # Random next
                is_random_next = False
                if len(current_chunk) == 1 or rng.random() < 0.5:
                    is_random_next = True
                    target_b_length = target_seq_length - len(tokens_a)

                    # Up to 10 draws to find a document other than this one.
                    # NOTE(review): if all 10 draws return document_index (always
                    # the case for a single-document corpus) the loop falls
                    # through and B is taken from this same document while still
                    # labelled as a random next. Confirm this is acceptable.
                    for _ in range(10):
                        random_document_index = rng.randint(0, len(all_documents) - 1)
                        if random_document_index != document_index:
                            break

                    random_document = all_documents[random_document_index]
                    random_start = rng.randint(0, len(random_document) - 1)
                    for j in range(random_start, len(random_document)):
                        tokens_b.extend(random_document[j])
                        if len(tokens_b) >= target_b_length:
                            break
                    # The chunk's segments after a_end were not used. Rewind `i`
                    # so the next chunk starts at the first unused segment.
                    num_unused_segments = len(current_chunk) - a_end
                    i -= num_unused_segments
                # Actual next
                else:
                    is_random_next = False
                    for j in range(a_end, len(current_chunk)):
                        tokens_b.extend(current_chunk[j])
                # Trim A and B so that together they fit in max_num_tokens.
                truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

                assert len(tokens_a) >= 1
                assert len(tokens_b) >= 1

                # Layout: [CLS] A [SEP] B [SEP]. Segment id 0 covers [CLS], A and
                # the first [SEP]; segment id 1 covers B and the final [SEP].
                tokens = []
                segment_ids = []
                tokens.append("[CLS]")
                segment_ids.append(0)
                for token in tokens_a:
                    tokens.append(token)
                    segment_ids.append(0)

                tokens.append("[SEP]")
                segment_ids.append(0)

                for token in tokens_b:
                    tokens.append(token)
                    segment_ids.append(1)
                tokens.append("[SEP]")
                segment_ids.append(1)

                (tokens, masked_lm_positions,
                 masked_lm_labels) = create_masked_lm_predictions(
                     tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
                instance = TrainingInstance(
                    tokens=tokens,
                    segment_ids=segment_ids,
                    is_random_next=is_random_next,
                    masked_lm_positions=masked_lm_positions,
                    masked_lm_labels=masked_lm_labels)
                instances.append(instance)
            # Start a fresh chunk after every emission point.
            current_chunk = []
            current_length = 0
        i += 1

    return instances


instance = TrainingInstance(

1)一个instance 包含一个tokens,实际上就是输入的词序列;该序列表现形式为:


A=[token_0, token_1, ...,token_i]
B=[token_i+1, token_i+2, ...,token_n-1]

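2 <= n <= max_seq_length - 3 (概率为 short_seq_prob,此时目标长度在 [2, max_seq_length - 3] 内随机取值,randint 两端都包含)
n = max_seq_length - 3 (概率为 1 - short_seq_prob)
其中减去的 3 是为 [CLS] 和两个 [SEP] 预留的位置。严格来说,这里约束的是目标长度 target_seq_length:文档末尾句子不够时,实际的 n 会更短;句子累积超过目标长度时,实际的 n 也可能更长,但最终都会被 truncate_seq_pair 截断到不超过 max_seq_length - 3。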

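tokens 最终的形式为:[CLS] token_0 ... token_i [SEP] token_i+1 ... token_n-1 [SEP],即 [CLS] A [SEP] B [SEP]。其中部分 token 还会被替换为 [MASK] 等(见下文的示例输出)。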


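segment_ids 的形式为 [0,0,...,0,1,1,...,1]:[CLS]、A 段的全部 token 以及第一个 [SEP] 对应 0,共 i+3 个;B 段的全部 token 以及最后一个 [SEP] 对应 1,共 n-i 个。两者合计 n+3 个,不超过 max_seq_length;不足 max_seq_length 的部分在写入 TFRecord 时才会补 0。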






def write_instance_to_example_files(instances, tokenizer, max_seq_length,
                                    max_predictions_per_seq, output_files):
    """Create TF example files from `TrainingInstance`s.

    Each instance is converted to fixed-length features and written as a
    serialized ``tf.train.Example``:
      * token ids, input mask and segment ids are zero-padded to
        ``max_seq_length``;
      * masked-LM positions, ids and weights are zero-padded to
        ``max_predictions_per_seq``.
    Instances are distributed round-robin over ``output_files``. The first 20
    examples are logged for inspection.

    Args:
      instances: iterable of ``TrainingInstance`` objects.
      tokenizer: object providing ``convert_tokens_to_ids``.
      max_seq_length: length that every sequence feature is padded to.
      max_predictions_per_seq: length that the masked-LM features are padded to.
      output_files: list of output paths; one ``TFRecordWriter`` is opened per path.
    """
    writers = []
    total_written = 0
    # Close every writer that was opened, even if an exception occurs while
    # opening, converting or writing, so no writer is left open.
    try:
        for output_file in output_files:
            writers.append(tf.python_io.TFRecordWriter(output_file))

        writer_index = 0
        for (inst_index, instance) in enumerate(instances):
            input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
            # 1 marks a real token; padding positions get 0 below.
            input_mask = [1] * len(input_ids)
            segment_ids = list(instance.segment_ids)
            assert len(input_ids) <= max_seq_length

            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            masked_lm_positions = list(instance.masked_lm_positions)
            masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
            # Weight 1.0 for real predictions; padding slots get 0.0 below.
            masked_lm_weights = [1.0] * len(masked_lm_ids)

            while len(masked_lm_positions) < max_predictions_per_seq:
                masked_lm_positions.append(0)
                masked_lm_ids.append(0)
                masked_lm_weights.append(0.0)

            # 1 when B is a random next (is_random_next), 0 when B is the
            # actual continuation of A.
            next_sentence_label = 1 if instance.is_random_next else 0

            features = collections.OrderedDict()
            features["input_ids"] = create_int_feature(input_ids)
            features["input_mask"] = create_int_feature(input_mask)
            features["segment_ids"] = create_int_feature(segment_ids)
            features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
            features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
            features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
            features["next_sentence_labels"] = create_int_feature([next_sentence_label])

            tf_example = tf.train.Example(features=tf.train.Features(feature=features))

            writers[writer_index].write(tf_example.SerializeToString())
            writer_index = (writer_index + 1) % len(writers)

            total_written += 1

            if inst_index < 20:
                tf.logging.info("*** Example ***")
                tf.logging.info("tokens: %s" % " ".join(
                    [tokenization.printable_text(x) for x in instance.tokens]))

                for feature_name, feature in features.items():
                    values = []
                    if feature.int64_list.value:
                        values = feature.int64_list.value
                    elif feature.float_list.value:
                        values = feature.float_list.value
                    tf.logging.info(
                        "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
    finally:
        for writer in writers:
            writer.close()

    tf.logging.info("Wrote %d total instances", total_written)


# Excerpt from write_instance_to_example_files: right-pad the token ids, the
# input mask and the segment ids with 0 until they reach max_seq_length.
# Real tokens have input_mask 1, so a 0 in the mask marks a padding position.
while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
    segment_ids.append(0)

2) 把instance的is_random_next转化成变量next_sentence_label保存。


python3 create_pretraining_data.py --input_file=/tmp/zh_test.txt --output_file=/tmp/output.txt --vocab_file=$BERT_ZH_DIR/vocab.txt






INFO:tensorflow:*** Example ***
INFO:tensorflow:tokens: [CLS] i 觉 得 [UNK] u [MASK] 非 [MASK] 位 i 风 格 较 ##by 哦 个 驅 色 哦 i 多 发 [MASK] 个 v 二 哥 i 文 件 哦 i 怪 [MASK] 决 斗 盘 可 加 热 管 [MASK] u [MASK] [MASK] 文 集 狗 哥 [SEP] [MASK] [UNK] 奇 偶 均 衡 能 否 v 不 。 极 [MASK] 疯 狂 减 肥 的 人 能 否 打 开 v 高 科 技 就 而 [MASK] 就 [UNK] 哦 冏 结 构 i 恶 如 桂 萼 黑 人 牙 膏 [UNK] u 我 也 【 发 票 未 开 [MASK] 俄 日 [MASK] 件 二 我 就 佛 i 额 [MASK] 阶 [MASK] 感 v [MASK] 我 为 [MASK] 军 方 [SEP]
INFO:tensorflow:input_ids: 101 151 6230 2533 100 163 103 7478 103 855 151 7599 3419 6772 8684 1521 702 7705 5682 1521 151 1914 1355 103 702 164 753 1520 151 3152 816 1521 151 2597 103 1104 3159 4669 1377 1217 4178 5052 103 163 103 103 3152 7415 4318 1520 102 103 100 1936 981 1772 6130 5543 1415 164 679 511 3353 103 4556 4312 1121 5503 4638 782 5543 1415 2802 2458 164 7770 4906 2825 2218 5445 103 2218 100 1521 1087 5310 3354 151 2626 1963 3424 5861 7946 782 4280 5601 100 163 2769 738 523 1355 4873 3313 2458 103 915 3189 103 816 753 2769 2218 867 151 7583 103 7348 103 2697 164 103 2769 711 103 1092 3175 102
INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
INFO:tensorflow:masked_lm_positions: 6 8 14 17 23 34 42 44 45 46 51 63 80 105 108 116 118 121 124 0
INFO:tensorflow:masked_lm_ids: 5445 1392 711 5106 1126 1077 100 702 782 3152 2533 2428 1400 163 7353 1912 5277 8024 862 0
INFO:tensorflow:masked_lm_weights: 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 0.0
INFO:tensorflow:next_sentence_labels: 1






    # Model inputs (X): token ids, attention mask and segment ids.
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    # Training targets (Y): masked-LM positions/ids/weights and the
    # next-sentence label.
    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]
    next_sentence_labels = features["next_sentence_labels"]

    # segment_ids are fed to the model as token_type_ids.
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

其中input_ids、input_mask 、segment_ids 作为X,剩下的masked_lm_positions、masked_lm_ids 、masked_lm_weights 、next_sentence_labels 共同作为Y

2、 loss

     # Masked-LM head and next-sentence head; the pre-training loss is their sum.
    # Restored the first tuple element `masked_lm_loss`, which total_loss uses.
    (masked_lm_loss,
     masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
         bert_config, model.get_sequence_output(), model.get_embedding_table(),
         masked_lm_positions, masked_lm_ids, masked_lm_weights)

    (next_sentence_loss, next_sentence_example_loss,
     next_sentence_log_probs) = get_next_sentence_output(
         bert_config, model.get_pooled_output(), next_sentence_labels)

    total_loss = masked_lm_loss + next_sentence_loss

可以看到loss 分别由masked_lm_loss和next_sentence_loss组成,masked_lm_loss针对的是语言模型对MASK起来的标签的预测,即上下文语境预测当前词;而next_sentence_loss是对于句子关系的预测。前者在迁移学习中可以用于标注类任务(分词、NER等),后者可以用于句子关系任务(QA、自然语言推理等)。

需要多说一句的是,masked_lm_loss 用到了模型的 sequence_output 和 embedding_table。对多个 MASK 位置的标签进行预测是一个标注问题,所以需要取最后一层的整个序列输出。embedding_table 则用来做"反 embedding":把这些位置的隐向量与词向量矩阵相乘,映射回词表空间,从而得到每个 MASK 位置在整个词表上的预测分布。next_sentence_loss 用到的是 pooled_output,它对应第一个 token [CLS] 的输出,一般用于分类任务的学习。






Bert系列(五)——中文分词实践 F1 97.8%(附代码)

2. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

