BERT文本分类使用指南

本文档介绍了如何使用BERT实现多类别文本分类任务,适合稍微了解BERT和文本分类的同学参考。

(一) 下载

首先,在github上clone谷歌的BERT项目,或者直接下载。项目地址

然后,下载中文预训练模型,地址

(二) 环境准备

tensorflow >= 1.11.0

注意:

  1. 在GPU上运行Tensorflow,需要CUDA版本和Tensorflow版本的对应。比如Tensorflow-1.11.0最高只能使用9.0版本的CUDA,否则加载时会出现找不到libcublas.so的错误。
  2. 安装TensorFlow时,如果出现无法卸载enum34的错误,可以用pip install *** --ignore_installed enum34命令先跳过。

(三) 数据准备

准备数据集,包括训练集、验证集、测试集,格式相同,每行为一个类别+文本,用“\t”间隔。(如果选择其他间隔符,需要修改run_classifier.py中_read_tsv方法)。

我做的是新闻文本分类,数据格式如下:

在这里插入图片描述

(四) 修改run_classifier.py文件

  1. 添加处理数据集的类,class ZbsProcessor(DataProcessor),分别实现以下方法:
def get_train_examples(self, data_dir): 读取训练集
def get_dev_examples(self, data_dir): 读取验证集
def get_test_examples(self, data_dir): 读取测试集
def get_labels(self, labels): 获得类别集合
def _create_examples(self, lines, set_type): 生成训练和验证样本
  1. 修改main函数。在第744行,将ZbsProcessor添加到processors中
processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mrpc": MrpcProcessor,
    "xnli": XnliProcessor,
    "zbs": ZbsProcessor
} 
  1. 原代码中,先判断是否train,然后获取训练样本,但是后面需要所有类别,所以需要改成先获取所有类别,然后判断判断是否train。即代码:
if FLAGS.do_train:
  train_examples = processor.get_train_examples(FLAGS.data_dir)

改为:

train_examples,train_labels,temp=processor.get_train_examples(FLAGS.data_dir)
label_list = processor.get_labels(train_labels)
if FLAGS.do_train:
  1. 修改运行参数。可以直接在代码里修改,也可以执行.py文件时提供参数。参数意义:
data_dir:存放数据集的文件夹
bert_config_file:bert中文模型中的bert_config.json文件
task_name:processors中添加的任务名“zbs”
vocab_file:bert中文模型中的vocab.txt文件
output_dir:训练好的分类器模型的存放文件夹
init_checkpoint:bert中文模型中的bert_model.ckpt.index文件
do_train:是否训练,设置为“True”
do_eval:是否验证,设置为“True”
do_predict:是否测试,设置为“False

可调参数:

max_seq_length:输入文本序列的最大长度,也就是每个样本的最大处理长度,多余会去掉,不够会补齐。最大值512。
train_batch_size: 训练模型求梯度时,批量处理数据集的大小。值越大,训练速度越快,内存占用越多。
eval_batch_size: 验证时,批量处理数据集的大小。同上。
predict_batch_size: 测试时,批量处理数据集的大小。同上。
learning_rate: 反向传播更新权重时,步长大小。值越大,训练速度越快。值越小,训练速度越慢,收敛速度慢,
    容易过拟合。迁移学习中,一般设置较小的步长(小于2e-4)
num_train_epochs:所有样本完全训练一遍的次数。
warmup_proportion:用于warmup的训练集的比例。
save_checkpoints_steps:检查点的保存频率。

(五) 运行。

如果在文件中已经设置后参数,直接运行即可。

也可以在执行.py文件时,传入参数,例如:

python zbs_classifier.py --data_dir=/home/hls/bert_zbs_data/data2c-11 
--init_checkpoint=/home/hls/bert_zbs_data/data2c-11/out1/model.ckpt-3616.index
--output_dir=/home/hls/bert_zbs_data/data2c-11/out3 
--max_seq_length=256 
--learning_rate=2e-5 
--num_train_epochs=50

(六) 附录代码

class ZbsProcessor(DataProcessor):
  def get_train_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self, labels):
    return set(labels)

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    labels = []
    labels_test = []
    for (i, line) in enumerate(lines):
      if i == 0:
        continue
      guid = "%s-%s" % (set_type, i)
      text_a = tokenization.convert_to_unicode(line[1])
      if set_type == "test":
        label = "台湾" #这里要设置成数据集中一个真实的类别
      else:
        label = tokenization.convert_to_unicode(line[0])
      labels.append(label)  
      examples.append(  
          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))   
    return examples, labels, labels_test

(七)代码结构

class InputExample(object):

class InputFeatures(object):

class DataProcessor(object):

class ZbsProcessor(DataProcessor):

def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer):
 将单条训练数据,由 class InputExample 结构转换成 class InputFeature 的结构

def file_based_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, output_file):
 遍历训练样本,将其转换成InputFeatures特征,并保存到train.TFRecord文件中。调用convert_single_example()方法实现单条数据转换。

def file_based_input_fn_builder(input_file, seq_length, is_training,
drop_remainder):
 根据保存的训练文件train.TFRecord,生成tf.data.TFRecordDataset用于提供给Estimator来训练。

def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, num_train_steps, num_warmup_steps, use_tpu, use_one_hot_embeddings):
 返回tf.contrib.tpu.TPUEstimatorSpec对象。

BERT文本分类使用指南_第1张图片

(八) 有趣的ISSUES

  1. 如何在训练时输出loss
    logging_hook = tf.train.LoggingTensorHook({“loss”: total_loss}, every_n_iter=10)
    output_spec = tf.contrib.tpu.TPUEstimatorSpec(
    mode=mode,
    loss=total_loss,
    train_op=train_op,
    training_hooks=[logging_hook],
    scaffold_fn=scaffold_fn)

你可能感兴趣的:(深度学习算法)