代码已上传至github https://github.com/danan0755/Bert_Classifier
数据集使用 cnews,可通过百度网盘下载:
链接:https://pan.baidu.com/s/1LzTidW_LrdYMokN---Nyag
提取码:zejw
数据格式如下:每行一条样本,先是类别标签,再是一个制表符(\t),然后是新闻正文,即 `标签\t正文`。标签共 10 类:体育、娱乐、家居、房产、教育、时尚、时政、游戏、科技、财经。
BERT 中文预训练模型(chinese_L-12_H-768_A-12)下载地址:
链接:https://pan.baidu.com/s/14JcQXIBSaWyY7bRWdJW7yg
提取码:mvtl
复制 run_classifier.py 并重命名为 run_cnews_classifier.py,然后在其中添加自定义的 Processor(下面的代码用到了 random 模块,若文件中尚未导入,需在文件开头添加 import random):
class MyProcessor(DataProcessor):
    """Processor for the cnews news-classification data set.

    Every line of cnews.train.txt / cnews.val.txt / cnews.test.txt is parsed
    as a label, a tab character, then the news text. Only a deterministic,
    shuffled subset of each file is used (see SUBSET_SIZES) so experiments
    stay fast.
    """

    # Number of shuffled lines kept per split ("dev" reads cnews.val.txt).
    # A flag missing from this dict, or a value of None, keeps the whole file.
    SUBSET_SIZES = {"train": 5000, "dev": 500, "test": 100}

    def read_txt(self, data_dir, flag):
        """Return a shuffled, truncated list of raw lines from one cnews file.

        Args:
          data_dir: path of the data *file* to read (the name is historical).
          flag: "train", "dev" or "test"; selects the size in SUBSET_SIZES.

        Returns:
          The selected lines, still carrying their trailing newlines.
        """
        import random  # local import so the method does not rely on a module-level one

        with open(data_dir, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        # A private fixed-seed generator yields the same order as
        # random.seed(0) followed by random.shuffle(lines), but does not
        # reseed the process-wide random module as a side effect.
        random.Random(0).shuffle(lines)
        limit = self.SUBSET_SIZES.get(flag)
        if limit is not None:
            lines = lines[:limit]
        return lines

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_txt(os.path.join(data_dir, "cnews.train.txt"), "train"), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_txt(os.path.join(data_dir, "cnews.val.txt"), "dev"), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self.read_txt(os.path.join(data_dir, "cnews.test.txt"), "test"), "test")

    def get_labels(self):
        """See base class."""
        return ["体育", "娱乐", "家居", "房产", "教育", "时尚", "时政", "游戏", "科技", "财经"]

    def _create_examples(self, lines, set_type):
        """Create InputExamples from tab-separated ``label, text`` lines.

        Blank lines are ignored. For the "test" split the label column is not
        read; every example gets the placeholder label "体育", which is one of
        get_labels().

        Args:
          lines: raw lines as returned by read_txt().
          set_type: "train", "dev" or "test"; used as the guid prefix.

        Returns:
          A list of InputExample, one per non-blank line.

        Raises:
          ValueError: if a non-blank line has no tab between label and text.
        """
        examples = []
        for (i, line) in enumerate(lines):
            # No header skipping: read_txt() shuffles the file first, so index
            # 0 is an ordinary sample. Skipping it (as the GLUE processors do
            # for their header rows) silently dropped one example per split.
            line = line.strip()
            if not line:
                continue
            guid = "%s-%s" % (set_type, i)
            # Split only once: everything after the first tab is the text.
            split_line = line.split("\t", 1)
            if len(split_line) != 2:
                raise ValueError(
                    "Malformed line %s: expected '<label>\\t<text>'" % guid)
            text_a = tokenization.convert_to_unicode(split_line[1])
            text_b = None
            if set_type == "test":
                label = "体育"
            else:
                label = tokenization.convert_to_unicode(split_line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
在 main 函数的 processors 字典中注册自定义的 Processor,对应命令行参数 --task_name=cnews:
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Maps the --task_name flag value to its DataProcessor class. The "cnews"
    # entry is the one added for this task (selected with --task_name=cnews);
    # it points at the custom MyProcessor.
    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
        "cnews": MyProcessor
    }
    # NOTE(review): the rest of main() is not shown in this excerpt;
    # presumably it is left as in run_classifier.py.
训练命令(同时在验证集上评估,--do_eval=true):
python run_cnews_classifier.py --task_name=cnews --do_train=true --do_eval=true --do_predict=false --data_dir=cnews --vocab_file=pretrained_model/chinese_L-12_H-768_A-12/vocab.txt --bert_config_file=pretrained_model/chinese_L-12_H-768_A-12/bert_config.json --init_checkpoint=pretrained_model/chinese_L-12_H-768_A-12/bert_model.ckpt --train_batch_size=32 --max_seq_length=128 --output_dir=model
预测命令(--init_checkpoint 指向训练输出目录 model 中的检查点;468 是训练结束时的 global_step,请按实际训练步数修改):
python run_cnews_classifier.py --task_name=cnews --do_train=false --do_eval=false --do_predict=true --data_dir=cnews --vocab_file=pretrained_model/chinese_L-12_H-768_A-12/vocab.txt --bert_config_file=pretrained_model/chinese_L-12_H-768_A-12/bert_config.json --init_checkpoint=model/model.ckpt-468 --max_seq_length=128 --output_dir=result
验证集评估结果:
INFO:tensorflow: eval_accuracy = 0.93386775
INFO:tensorflow: eval_loss = 0.33081177
INFO:tensorflow: global_step = 468
INFO:tensorflow: loss = 0.3427003