机器学习算法与自然语言处理出品
@公众号原创专栏作者 刘聪NLP
学校 | 中国药科大学 药学信息学硕士
知乎专栏 | 自然语言处理相关论文
本文的目的是介绍如何使用Web服务快速部署深度学习模型,虽然TF有TFserving可以进行模型部署,但是对于Pytorch无能为力(如果要使用的话需要把torch模型进行转换,有些麻烦);因此,本文在这里介绍一种使用Web服务部署深度学习的方法(简单有效,不喜勿喷)。
本文以简单的新闻分类模型来举例,模型:BERT;数据来源:清华新闻语料(地址:
THUCTC: 一个高效的中文文本分类工具),清华新闻语料共有14个类别,分别是体育,娱乐,家居,彩票,房产,教育,时尚,时政,星座,游戏,社会,科技,股票和财经。为了快速训练模型,本人在每个类别中分别随机挑选1000个作为训练集,200个作为验证集。数据预处理、模型训练和pb模型保存代码见:新闻分类模型训练github地址。(非重点,不过多介绍了,github上有详细的使用说明,有问题可留言。)
为了使web服务部署变得简洁,因此本人构造一个方法类,方便加载pb模型,对传入文本进行数据预处理以及进行模型预测。
模型初始化代码如下:
import bert_tokenization
import tensorflow as tf
from tensorflow.python.platform import gfile
import numpy as np
import os
class ClassificationModel(object):
def __init__(self):
self.tokenizer = None
self.sess = None
self.is_train = None
self.input_ids = None
self.input_mask = None
self.segment_ids = None
self.predictions = None
self.max_seq_length = None
self.label_dict = ['体育', '娱乐', '家居', '彩票', '房产', '教育', '时尚', '时政', '星座', '游戏', '社会', '科技', '股票', '财经']
其中,tokenizer 为分词器;sess为TF的session模块;is_train、input_ids、input_mask和segment_ids分别是pb模型的输入;predictions为pb模型的输出;max_seq_length为模型的最大输入长度;label_dict为新闻分类标签。
加载pb模型代码如下:
def load_model(self, gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length):
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
self.tokenizer = bert_tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
sess_config = tf.ConfigProto(gpu_options=gpu_options)
self.sess = tf.Session(config=sess_config)
with gfile.FastGFile(model_path, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
self.sess.graph.as_default()
tf.import_graph_def(graph_def, name="")
self.sess.run(tf.global_variables_initializer())
self.is_train = self.sess.graph.get_tensor_by_name("input/is_train:0")
self.input_ids = self.sess.graph.get_tensor_by_name("input/input_ids:0")
self.input_mask = self.sess.graph.get_tensor_by_name("input/input_mask:0")
self.segment_ids = self.sess.graph.get_tensor_by_name("input/segment_ids:0")
self.predictions = self.sess.graph.get_tensor_by_name("output_layer/predictions:0")
self.max_seq_length = max_seq_length
其中,gpu_id为使用GPU的序号;vocab_file为BERT模型所使用的字典路径;gpu_memory_fraction为使用GPU时所占用的比例;model_path为pb模型的路径;max_seq_length为BERT模型的最大长度。
将传入文本转化成模型所需格式代码如下:
def convert_fearture(self, text):
max_seq_length = self.max_seq_length
max_length_context = max_seq_length - 2
content_token = self.tokenizer.tokenize(text)
if len(content_token) > max_length_context:
content_token = content_token[:max_length_context]
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in content_token:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
input_ids = np.array(input_ids)
input_mask = np.array(input_mask)
segment_ids = np.array(segment_ids)
return input_ids, input_mask, segment_ids
预测代码如下:
def predict(self, text):
input_ids_temp, input_mask_temp, segment_ids_temp = self.convert_fearture(text)
feed = {self.is_train: False,
self.input_ids: input_ids_temp.reshape(1, self.max_seq_length),
self.input_mask: input_mask_temp.reshape(1, self.max_seq_length),
self.segment_ids: segment_ids_temp.reshape(1, self.max_seq_length)}
[label] = self.sess.run([self.predictions], feed)
label_name = self.label_dict[label[0]]
return label[0], label_name
其中,输入是一个新闻文本,输出为类别序号以及对应的标签名称。详细完整代码见github:
ClassificationModel.py文件。
(划重点)上面介绍的都是如何方便简洁地加载模型,下面开始使用web服务挂起模型。通俗地讲,其实本人就是通过flask框架,搭建了一个web服务,来获取外部的输入;并且使用挂载的模型进行预测;最后将预测结果通过web服务传出。
from gevent import monkey
monkey.patch_all()
from flask import Flask, request
from gevent import wsgi
import json
from ClassificationModel import ClassificationModel
def start_sever(http_id, port, gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length):
model = ClassificationModel()
model.load_model(gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length)
print("load model ending!")
app = Flask(__name__)
@app.route('/')
def index():
return "This is News Classification Model Server"
@app.route('/news-classification', methods=['Get', 'POST'])
def response_request():
if request.method == 'POST':
text = request.form.get('text')
else:
text = request.args.get('text')
label, label_name = model.predict(text)
d = {"label": str(label), "label_name": label_name}
print(d)
return json.dumps(d, ensure_ascii=False)
server = wsgi.WSGIServer((str(http_id), port), app)
server.serve_forever()
其中,http_id为web服务的地址;port为端口号;gpu_id、vocab_file、gpu_memory_fraction、model_path和max_seq_length为上面介绍的加载模型所需要的参数,详细见上文。
index函数用于检验web服务是否畅通。如图1所示。
response_request函数为响应函数。定义了两种请求数据的方式,get和post。当使用get方法获取web输入时,获取命令为request.args.get(‘text’);当使用post方法获在这里插入代码片
取web输入时,获取命令为request.form.get(‘text’)。
当web服务起起来之后,就可以调用啦!!!
import requests
def http_test(text):
url = 'http://127.0.0.1:5555/news-classification'
raw_data = {'text': text}
res = requests.post(url, raw_data)
result = res.json()
return result
if __name__ == "__main__":
text = "姚明在NBA打球,很强。"
result = http_test(text)
print(result["label_name"])
以上就是通过web服务部署深度学习模型的全部内容,喜欢的同学还请多多点赞~~~~~