BERT本质上是一个两段式的NLP模型。第一个阶段叫做:Pre-training,跟WordEmbedding类似,利用现有无标记的语料训练一个语言模型。第二个阶段叫做:Fine-tuning,利用预训练好的语言模型,完成具体的NLP下游任务。pre-training的训练成本很大,一般直接使用google训练好的模型,而fine-tuning成本相对较少,本文介绍如何进行fine-tuning,对应的程序为run_classifier.py
,如下所示,从主函数开始
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
# 新增类,用于处理数据,加载训练数据
processors = { # 【2】新增类
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
"mypro": MyProcessor
}
tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
FLAGS.init_checkpoint)
if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
raise ValueError(
"At least one of `do_train`, `do_eval` or `do_predict' must be True.")
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
if FLAGS.max_seq_length > bert_config.max_position_embeddings:
raise ValueError(
"Cannot use sequence length %d because the BERT model "
"was only trained up to sequence length %d" %
(FLAGS.max_seq_length, bert_config.max_position_embeddings))
tf.gfile.MakeDirs(FLAGS.output_dir)
# 读取task_name, task_name实际上就是选择或自定义的processor,若自定义需要加入processors字典中
task_name = FLAGS.task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
# 获取数据处理方法
processor = processors[task_name]()
# 获取标签
label_list = processor.get_labels()
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
....
....
从程序中可以看出需要自定义一个类来处理原始数据,用于训练, 并将该类加入processors字典当中,数据处理类可以参考已经存在的类。此处我自定义一个MyProcessor
类来加载原始数据, 代码如下
class MyProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
def __init__(self):
# 【0】 设置语言
self.language = "zh"
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1", "2", "3", "4", "5", "6", "7", "8"] # 【1】设置标签label
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1]) # 【2】text_a/text_b根据实际语料结构更改
text_b = None
if set_type == "test":
label = "-1"
else:
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
该类中主要定义几个函数分别进行获取训练(测试、验证)样本(get_train_examples
),获取标签值(get_labels
),同时也能发现,在data_dir中我们需要将数据处理成.tsv格式,训练集、开发集和测试集分别是train.tsv, dev.tsv, test.tsv,这里我们暂时只使用train.tsv和dev.tsv。另外,label在get_labels()设定,如果是二分类,则将label设定为[“0”,”1”],同时_create_examples()中,给定了如何获取guid以及如何给text_a, text_b和label赋值。
对于这个fine-tuning过程,我们要做的只是:
.tsv文件格式如下:标签+table+句子
0 i love china
1 a beautiful day
0 i love shanghai
更详细介绍fine_tuning过程及运行代码见orangerfun github
BERT源码