Pre-training with the official BERT code, step by step


1. First, prepare the data needed for pre-training


python create_pretraining_data.py \
  --input_file=./sample_text.txt \
  --output_file=/tmp/tf_examples.tfrecord \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --do_lower_case=True \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --masked_lm_prob=0.15 \
  --random_seed=12345

  • max_predictions_per_seq: the maximum number of masked LM predictions per sequence. Set it to roughly max_seq_length * masked_lm_prob (the script does not do this automatically).
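
As a quick check of that suggestion, the sketch below multiplies the two values and rounds up. The helper name suggested_max_predictions is hypothetical and not part of the BERT scripts, and rounding up is an assumption that happens to agree with the values 128, 0.15 and 20 used in the command above.

```python
import math


def suggested_max_predictions(max_seq_length, masked_lm_prob):
    """Hypothetical helper: suggested value for --max_predictions_per_seq."""
    # Round first so floating-point noise (e.g. 15.000000000000002)
    # does not push the ceiling up by one.
    expected = round(max_seq_length * masked_lm_prob, 6)
    return math.ceil(expected)


print(suggested_max_predictions(128, 0.15))  # prints 20
```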

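Input text format: one sentence per line (this matters for next sentence prediction), with blank lines separating documents. For example, the sample_text.txt that ships with the source code: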

This text is included to make sure Unicode is handled properly: 力加勝北区ᴵᴺᵀᵃছজটডণত
Text should be one-sentence-per-line, with empty lines between documents.
This sample text is public domain and was randomly selected from Project Guttenberg.

The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors.
Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity.
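
To make the format concrete, here is a minimal sketch that reads text laid out this way back into documents and sentences. It is not the parsing code inside create_pretraining_data.py; the function name split_documents and the tiny sample string are made up for illustration.

```python
def split_documents(text):
    """Group non-empty lines into documents separated by blank lines."""
    documents, current = [], []
    for line in text.splitlines():
        sentence = line.strip()
        if sentence:
            current.append(sentence)
        elif current:
            # A blank line closes the current document.
            documents.append(current)
            current = []
    if current:
        documents.append(current)
    return documents


sample = "Sentence one.\nSentence two.\n\nAnother document.\n"
docs = split_documents(sample)
print(len(docs), [len(d) for d in docs])  # prints 2 [2, 1]
```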



2. Next, run pre-training


python run_pretraining.py \
  --input_file=/tmp/tf_examples.tfrecord \
  --output_dir=/tmp/pretraining_output \
  --do_train=True \
  --do_eval=True \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --train_batch_size=32 \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --num_train_steps=20 \
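  --num_warmup_steps=10

  • If you are pre-training from scratch, leave out init_checkpoint.
  • The model configuration (including the vocab size) is set in bert_config_file.
  • In practice, num_train_steps should usually be set to 10000 or more.
  • max_seq_length and max_predictions_per_seq must be the same values that were passed to create_pretraining_data.py.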
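
One way to keep the two scripts in agreement is to build both command lines from a single set of shared values. The sketch below only assembles argument lists and does not run anything; build_command and SHARED are hypothetical names, and only a subset of the flags from the commands above is shown.

```python
# Values that must be identical for both scripts.
SHARED = {"max_seq_length": 128, "max_predictions_per_seq": 20}


def build_command(script, extra_flags):
    """Return an argv list for `script` with the shared flags added."""
    flags = dict(extra_flags)
    flags.update(SHARED)
    return ["python", script] + ["--%s=%s" % (k, v) for k, v in flags.items()]


data_cmd = build_command("create_pretraining_data.py", {
    "input_file": "./sample_text.txt",
    "output_file": "/tmp/tf_examples.tfrecord",
    "masked_lm_prob": 0.15,
})
train_cmd = build_command("run_pretraining.py", {
    "input_file": "/tmp/tf_examples.tfrecord",
    "output_dir": "/tmp/pretraining_output",
    "num_train_steps": 20,
})
print(" ".join(data_cmd))
print(" ".join(train_cmd))
```

Passing either list to subprocess.run would launch the script, provided the remaining required flags (vocab_file, bert_config_file and so on) are added first.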




  if FLAGS.horovod and len(input_files) < hvd.size():
    raise ValueError("Input Files must be sharded")
  if FLAGS.amp and FLAGS.manual_fp16:
    raise ValueError("AMP and Manual Mixed Precision Training are both activated! Error")


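The first check above fails when there are fewer input files than Horovod workers (hvd.size()), so the input has to be split into shards. After sharding, run create_pretraining_data.py once on each shard to get the corresponding tf_examples.tfrecord_X, where X is whatever number you assigned to that shard.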
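
The sharding step itself can be sketched in a few lines. This is only an illustration under stated assumptions: shard_documents and shard_output_names are hypothetical helpers, documents are distributed round-robin so that no document is cut in half, and the worker count of 2 is made up.

```python
def shard_documents(documents, num_shards):
    """Distribute whole documents round-robin across num_shards lists."""
    if num_shards < 1:
        raise ValueError("num_shards must be at least 1")
    shards = [[] for _ in range(num_shards)]
    for index, document in enumerate(documents):
        shards[index % num_shards].append(document)
    return shards


def shard_output_names(prefix, num_shards):
    """Build tf_examples.tfrecord_X style names, X being the shard number."""
    return ["%s_%d" % (prefix, x) for x in range(num_shards)]


documents = [["a1", "a2"], ["b1"], ["c1", "c2"], ["d1"], ["e1"]]
num_workers = 2  # assumed number of Horovod workers
shards = shard_documents(documents, num_workers)
names = shard_output_names("/tmp/tf_examples.tfrecord", num_workers)
# Mirrors the check above: at least one input file per worker.
assert len(names) >= num_workers
for name, shard in zip(names, shards):
    print(name, len(shard))
```

Each shard would then be written out as its own text file (sentences one per line, blank line between documents) and passed to create_pretraining_data.py. In the official run_pretraining.py the input_file flag is documented as accepting a glob or a comma-separated list, so the resulting files can be passed together; whether a Horovod-enabled fork behaves the same way is an assumption to verify.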
