本文档介绍了如何使用BERT实现多类别文本分类任务,适合稍微了解BERT和文本分类的同学参考。
首先,在github上clone谷歌的BERT项目,或者直接下载。项目地址
然后,下载中文预训练模型,地址
tensorflow >= 1.11.0
注意:
准备数据集,包括训练集、验证集、测试集,格式相同,每行为一个类别+文本,用“\t”间隔。(如果选择其他间隔符,需要修改run_classifier.py中_read_tsv方法)。
我做的是新闻文本分类,数据格式如下:
def get_train_examples(self, data_dir): 读取训练集
def get_dev_examples(self, data_dir): 读取验证集
def get_test_examples(self, data_dir): 读取测试集
def get_labels(self, labels): 获得类别集合
def _create_examples(self, lines, set_type): 生成训练和验证样本
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
"zbs": ZbsProcessor
}
if FLAGS.do_train:
train_examples = processor.get_train_examples(FLAGS.data_dir)
改为:
train_examples,train_labels,temp=processor.get_train_examples(FLAGS.data_dir)
label_list = processor.get_labels(train_labels)
if FLAGS.do_train:
data_dir:存放数据集的文件夹
bert_config_file:bert中文模型中的bert_config.json文件
task_name:processors中添加的任务名“zbs”
vocab_file:bert中文模型中的vocab.txt文件
output_dir:训练好的分类器模型的存放文件夹
init_checkpoint:bert中文模型中的bert_model.ckpt.index文件
do_train:是否训练,设置为“True”
do_eval:是否验证,设置为“True”
do_predict:是否测试,设置为“False”
可调参数:
max_seq_length:输入文本序列的最大长度,也就是每个样本的最大处理长度,多余会去掉,不够会补齐。最大值512。
train_batch_size: 训练模型求梯度时,批量处理数据集的大小。值越大,训练速度越快,内存占用越多。
eval_batch_size: 验证时,批量处理数据集的大小。同上。
predict_batch_size: 测试时,批量处理数据集的大小。同上。
learning_rate: 反向传播更新权重时,步长大小。值越大,训练速度越快。值越小,训练速度越慢,收敛速度慢,
容易过拟合。迁移学习中,一般设置较小的步长(小于2e-4)
num_train_epochs:所有样本完全训练一遍的次数。
warmup_proportion:用于warmup的训练集的比例。
save_checkpoints_steps:检查点的保存频率。
如果在文件中已经设置后参数,直接运行即可。
也可以在执行.py文件时,传入参数,例如:
python zbs_classifier.py --data_dir=/home/hls/bert_zbs_data/data2c-11
--init_checkpoint=/home/hls/bert_zbs_data/data2c-11/out1/model.ckpt-3616.index
--output_dir=/home/hls/bert_zbs_data/data2c-11/out3
--max_seq_length=256
--learning_rate=2e-5
--num_train_epochs=50
class ZbsProcessor(DataProcessor):
def get_train_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self, labels):
return set(labels)
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
labels = []
labels_test = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
if set_type == "test":
label = "台湾" #这里要设置成数据集中一个真实的类别
else:
label = tokenization.convert_to_unicode(line[0])
labels.append(label)
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples, labels, labels_test
class InputExample(object):
class InputFeatures(object):
class DataProcessor(object):
class ZbsProcessor(DataProcessor):
def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer):
将单条训练数据,由 class InputExample 结构转换成 class InputFeature 的结构
def file_based_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, output_file):
遍历训练样本,将其转换成InputFeatures特征,并保存到train.TFRecord文件中。调用convert_single_example()方法实现单条数据转换。
def file_based_input_fn_builder(input_file, seq_length, is_training,
drop_remainder):
根据保存的训练文件train.TFRecord,生成tf.data.TFRecordDataset用于提供给Estimator来训练。
def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, num_train_steps, num_warmup_steps, use_tpu, use_one_hot_embeddings):
返回tf.contrib.tpu.TPUEstimatorSpec对象。