使用BERT-TensorFlow解决法研杯要素识别任务
,该任务其实是一个多标签文本分类任务。模型的具体不是本文重点,故于此不细细展开说明。本文重点阐述如何部署模型。
官方推荐TensorFlow模型在生产环境中提供服务时使用SavedModel格式。SavedModel格式是一种通用的、语言中立的、密闭的、可恢复的TensorFlow模型序列化格式。SavedModel封装了TensorFlow Saver,对于模型服务是一种标准的导出方法。
这里的estimator
部分也忽略,不详细说明,其关键是调用estimator的export_savedmodel
导出SaveModel格式的模型,注意serving_input_fn
的编写。其中的字段与后续POST中的数据字段相对应。
def serving_input_fn():
# 保存模型为SaveModel格式
# 采用最原始的feature方式,输入是feature Tensors。
# 如果采用build_parsing_serving_input_receiver_fn,则输入是tf.Examples
label_ids = tf.placeholder(tf.int32, [None, 20], name='label_ids') # 要素识别任务有20个类别
input_ids = tf.placeholder(tf.int32, [None, cfig.max_seq_length], name='input_ids')
input_mask = tf.placeholder(tf.int32, [None, cfig.max_seq_length], name='input_mask')
segment_ids = tf.placeholder(tf.int32, [None, cfig.max_seq_length], name='segment_ids')
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'label_ids': label_ids,
'input_ids': input_ids,
'input_mask': input_mask,
'segment_ids': segment_ids,
})()
return input_fn
if cfig.do_export:
estimator._export_to_tpu = False
estimator.export_savedmodel(cfig.export_dir, serving_input_fn)
检查模型:
saved_model_cli show --dir save_model/1 --all
先基于Docker拉取tensorflow/serving
镜像(PS:这是CPU版)。再基于镜像,启动容器:
docker run --rm -t -p 8501:8501 -v /home/liujiepeng/MachineComprehension/CAIL2019/ElementsRecognition/bert_tensorflow_multi_label/save_model:/models/cail_elem --name=tfserving_cail -e MODEL_NAME=cail_elem tensorflow/serving:latest
运行结果:
2019-09-21 03:24:48.782137: I tensorflow_serving/model_servers/server.cc:82] Building single TensorFlow model file config: model_name: cail_elem model_base_path: /models/cail_elem
2019-09-21 03:24:48.782580: I tensorflow_serving/model_servers/server_core.cc:462] Adding/updating models.
2019-09-21 03:24:48.782633: I tensorflow_serving/model_servers/server_core.cc:561] (Re-)adding model: cail_elem
2019-09-21 03:24:48.883257: I tensorflow_serving/core/basic_manager.cc:739] Successfully reserved resources to load servable {name: cail_elem version: 1}
2019-09-21 03:24:48.883351: I tensorflow_serving/core/loader_harness.cc:66] Approving load for servable version {name: cail_elem version: 1}
2019-09-21 03:24:48.883433: I tensorflow_serving/core/loader_harness.cc:74] Loading servable version {name: cail_elem version: 1}
2019-09-21 03:24:48.883530: I external/org_tensorflow/tensorflow/contrib/session_bundle/bundle_shim.cc:363] Attempting to load native SavedModelBundle in bundle-shim from: /models/cail_elem/1
2019-09-21 03:24:48.883581: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: /models/cail_elem/1
2019-09-21 03:24:48.917199: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { serve }
2019-09-21 03:24:48.948563: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 AVX512F FMA
2019-09-21 03:24:49.028645: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:202] Restoring SavedModel bundle.
2019-09-21 03:24:49.497106: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:151] Running initialization op on SavedModel bundle at path: /models/cail_elem/1
2019-09-21 03:24:49.543113: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:311] SavedModel load for tags { serve }; Status: success. Took 659522 microseconds.
2019-09-21 03:24:49.543191: I tensorflow_serving/servables/tensorflow/saved_model_warmup.cc:103] No warmup data file found at /models/cail_elem/1/assets.extra/tf_serving_warmup_requests
2019-09-21 03:24:49.543323: I tensorflow_serving/core/loader_harness.cc:86] Successfully loaded servable version {name: cail_elem version: 1}
2019-09-21 03:24:49.549907: I tensorflow_serving/model_servers/server.cc:324] Running gRPC ModelServer at 0.0.0.0:8500 ...
[warn] getaddrinfo: address family for nodename not supported
[evhttp_server.cc : 239] RAW: Entering the event loop ...
2019-09-21 03:24:49.557068: I tensorflow_serving/model_servers/server.cc:344] Exporting HTTP/REST API at:localhost:8501 ...
查看正在运行的容器:docker container ls
对原始请求进行封装,构建符合要求的POST请求:
# -*- coding: utf-8 -*-
# @CreatTime : 2019/9/20 11:46
# @Author : JasonLiu
# @FileName: test_tfserving.py
import requests
import json
import tensorflow as tf
import collections
import pdb
import numpy as np
from bert import tokenization
from utils import create_examples_text_list, convert_single_example
def test_request():
label_ids = 20*[0]
input_ids = 512*[1]
input_mask = 512*[1]
segment_ids = 512*[1]
data_dict_temp = {
'label_ids': label_ids,
'input_ids': input_ids,
'input_mask': input_mask,
'segment_ids': segment_ids,
}
data_list = []
data_list.append(data_dict_temp)
data = json.dumps({"signature_name": "serving_default", "instances": data_list})
headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/cail_elem:predict', data=data, headers=headers)
print(json_response.text)
predictions = json.loads(json_response.text)['predictions']
print(predictions)
def request_from_raw_text():
"""
:return:
"""
BERT_VOCAB = "/home/data1/ftpdata/pretrain_models/bert_tensoflow_version/bert-base-chinese-vocab.txt"
text_list = ["权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。",
"权人宏伟支行及宝成公司共22次向怡天公司催收借款全部本金及利息,均产生诉讼时效中断的法律效力,本案债权未过诉讼时效期间", # LN8
"2012年11月30日,原债权人工行锦州市分行向保证人锦州锅炉有限责任公司发出督促履行保证责任通知书,要求其履行保证责任,"
"2004年11月18日,原债权人工行锦州市分行采用国内挂号信函的方式向保证人锦州锅炉有限责任公司邮寄送达中国工商银行辽宁省分行督促履行保证责任通知书," # LN4
"锦州市凌河区公证处相关公证人员对此过程进行了公证。"
]
data_list = []
tokenizer = tokenization.FullTokenizer(vocab_file=BERT_VOCAB, do_lower_case=True)
predict_examples = create_examples_text_list(text_list)
for (ex_index, example) in enumerate(predict_examples):
feature = convert_single_example(ex_index, example,
512, tokenizer)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
features = {}
features["input_ids"] = feature.input_ids
features["input_mask"] = feature.input_mask
# pdb.set_trace()
features["segment_ids"] = feature.segment_ids
if isinstance(feature.label_ids, list):
label_ids = feature.label_ids
else:
label_ids = feature.label_ids[0]
features["label_ids"] = label_ids
# tf_example = tf.train.Example(features=tf.train.Features(feature=features))
data_list.append(features)
data = json.dumps({"signature_name": "serving_default", "instances": data_list})
headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/cail_elem:predict', data=data, headers=headers)
# print(json_response.text)
# pdb.set_trace()
predictions = json.loads(json_response.text)['predictions']
# print(predictions)
for p in range(len(predictions)):
p_list = predictions[p]
label_index = np.argmax(p_list)
print("content={},label={}".format(text_list[p], label_index+1))
print("total number=", len(text_list))
request_from_raw_text()
从运行效率来看,CPU推理上,整体偏慢。运行上述32条任务,耗时:
real 0m17.366s
user 0m1.815s
sys 0m0.997s
那么我们试试采用tensorflow/serving:latest:gpu
版。此时,我们需要特别注意的是,本地NVIDIA 显卡驱动和ensorflow/serving:gpu
版本的匹配问题。
由于机器cuda版本是9.0,而tensorflow/serving:latest-gpu
是对应cuda 10版本。所以,需要从https://hub.docker.com/r/tensorflow/serving/tags/
找到合适的gpu版本。最终发现tensorflow/serving:1.12.3-gpu是可以与机器适配的。所以,拉取该镜像:docker pull tensorflow/serving:1.12.3-gpu
运行容器:
nvidia-docker run -t --rm -p 8501:8501 -v /home/liujiepeng/MachineComprehension/CAIL2019/ElementsRecognition/bert_tensorflow_multi_label/save_model:/models/cail_elem -e MODEL_NAME=cail_elem tensorflow/serving:1.12.3-gpu
即可GPU方式启动服务。
再测试,发现运行32条任务的耗时如下:
real 0m5.574s
user 0m2.084s
sys 0m0.902s
提速明显。