1. BERT简介
BERT的全称为Bidirectional Encoder Representation from Transformers,是一个预训练的语言表征模型。它强调了不再像以往一样采用传统的单向语言模型或者把两个单向语言模型进行浅层拼接的方法进行预训练,而是采用新的masked language model(MLM),以致能生成深度的双向语言表征。BERT论文发表时提及在11个NLP(Natural Language Processing,自然语言处理)任务中获得了新的state-of-the-art的结果。
2. 环境配置
- Ubuntu16.04
- Anaconda3
- python >= 3.6
- tensorflow >= 1.12.0
- pandas
先安装conda
# 查看conda环境
conda info -e
通过conda创建一个新的环境bert,切换到bert环境
# 切换到bert环境
conda activate bert
3. ChnSentiCorp数据集
我们选取ChnSentiCorp数据集,里面包含7000 多条酒店评论数据,5000 多条正向评论,2000 多条负向评论,这些评论数据有两个字段:label, review。
数据字段:
label:1表示正向评论,0表示负向评论
review:评论内容
数据地址是:https://raw.githubusercontent.com/SophonPlus/ChineseNlpCorpus/master/datasets/ChnSentiCorp_htl_all/ChnSentiCorp_htl_all.csv
新建一个脚本split_data.py,拆分成训练集train.csv,开发集dev.csv,测试集test.csv,比例8:1:1。
import pandas as pd
df = pd.read_csv('ChnSentiCorp_htl_all.csv', dtype=str)
df = df.dropna()
df = df.applymap(lambda x: str(x).strip())
df = df.sample(frac=1).reset_index(drop=True)
# split train:dev:test as 8:1:1
train_df = df.iloc[:6212]
dev_df = df.iloc[6212:6989]
test_df = df.iloc[6989:]
train_df.to_csv('train.csv', sep=',', index=False)
dev_df.to_csv('dev.csv', sep=',', index=False)
test_df.to_csv('test.csv', sep=',', index=False)
脚本执行完成后:
├── ChnSentiCorp_htl_all.csv
├── dev.csv
├── split_data.py
├── test.csv
└── train.csv
4. 下载BERT源码和预训练模型
- 下载BERT源码
https://github.com/google-research/bert/
git clone https://github.com/google-research/bert.git
- 下载BERT中文预训练模型
下载地址: https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
解压到自定义目录下
├── bert_config.json
├── bert_model.ckpt.data-00000-of-00001
├── bert_model.ckpt.index
├── bert_model.ckpt.meta
└── vocab.txt
5. 修改代码
在run_classifier.py文件中有一个基类DataProcessor类,其代码如下:
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
在这个基类中定义了一个读取文件的静态方法_read_tsv,四个分别获取训练集,验证集,测试集和标签的方法。在run_classsifier.py文件中我们可以看到,google对于一些公开数据集已经写了一些processor,如XnliProcessor,MnliProcessor,MrpcProcessor和ColaProcessor。这给我们提供了一个很好的示例,指导我们如何针对自己的数据集来写processor。接下来我们要定义自己的数据处理的类,我们将新增的类命名为SentimentProcessor。
class SentimentProcessor(DataProcessor):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_csv(
os.path.join(data_dir, "train.csv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "train-%d" % (i)
text_a = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[0])
if label == tokenization.convert_to_unicode("contradictory"):
label = tokenization.convert_to_unicode("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_csv(
os.path.join(data_dir, "dev.csv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "dev-%d" % (i)
text_a = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[0])
if label == tokenization.convert_to_unicode("contradictory"):
label = tokenization.convert_to_unicode("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
lines = self._read_csv(
os.path.join(data_dir, "test.csv"))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "test-%d" % (i)
text_a = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[0])
if label == tokenization.convert_to_unicode("contradictory"):
label = tokenization.convert_to_unicode("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def get_labels(self):
"""See base class."""
return ["0", "1"]
@classmethod
def _read_csv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with tf.gfile.Open(input_file, "r") as f:
reader = csv.reader(f, delimiter=",", quotechar=None)
lines = []
for line in reader:
lines.append(line)
return lines
在processors中增加SentimentProcessor
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
"senti": SentimentProcessor,
}
6. 训练BERT模型
新建一个脚本文件train.sh,内容如下,日志文件输出到train.log,通过 tail -f train.log
查看,通过nvidia-smi
命令查看GPU状态。
参数说明:
data_dir: 训练数据的地址
task_name: processor的名字
vocab_file: 字典地址,用默认提供的就可以了,当然也可以自定义
bert_config_file: 配置文件
output_dir: 模型的输出地址
do_train: 是否做fine-tuning,默认为false,如果为true必须重写获取训练集的方法
do_eval: 是否运行验证集,默认为false,如果为true必须重写获取验证集的方法
do_predict: 是否做预测,默认为false,如果为true必须重写获取测试集的方法
#!/bin/bash
export BERT_BASE_DIR=bert-models/chinese_L-12_H-768_A-12
export MY_DATASET=data
export OUTPUT_PATH=output
export TASK_NAME=senti
nohup /home/peng/anaconda3/envs/bert/bin/python run_classifier.py \
--data_dir=$MY_DATASET \
--task_name=$TASK_NAME \
--output_dir=$OUTPUT_PATH \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--do_train=True \
--do_eval=True \
--do_predict=True \
--max_seq_length=128 \
--train_batch_size=16 \
--learning_rate=5e-5 \
--num_train_epochs=2.0 \
>train.log 2>&1 &
7. 训练结果
训练结果在自定义OUTPUT_PATH/eval_results.txt
中,
eval_accuracy = 0.84942085
eval_loss = 0.3728643
global_step = 776
loss = 0.3766538
测试集的预测结果在OUTPUT_PATH/test_results.tsv
中,
前5条数据格式如下,两列数据分别表示[0, 1]概率:
0.012343313 0.9876567
0.9637287 0.03627124
0.3622907 0.6377093
0.0120654255 0.9879346
0.41722867 0.5827713
test.csv数据集中前5条如下:
8. 参考资料
- 快速用BERT实现情感分析
- 基于Bert的中文情感分析代码及分析
- BERT_中文情感分类操作及代码
- 基于Bert的中文情感分析代码及分析
- BERT文本分类实战