Web服务部署深度学习模型

公众号关注 “ML_NLP”
设为 “星标”,重磅干货,第一时间送达!

在这里插入图片描述

机器学习算法与自然语言处理出品
@公众号原创专栏作者 刘聪NLP
学校 | 中国药科大学 药学信息学硕士
知乎专栏 | 自然语言处理相关论文

本文的目的是介绍如何使用Web服务快速部署深度学习模型,虽然TF有TFserving可以进行模型部署,但是对于Pytorch无能为力(如果要使用的话需要把torch模型进行转换,有些麻烦);因此,本文在这里介绍一种使用Web服务部署深度学习的方法(简单有效,不喜勿喷)。

本文以简单的新闻分类模型来举例,模型:BERT;数据来源:清华新闻语料(地址:

THUCTC: 一个高效的中文文本分类工具),清华新闻语料共有14个类别,分别是体育,娱乐,家居,彩票,房产,教育,时尚,时政,星座,游戏,社会,科技,股票和财经。为了快速训练模型,本人在每个类别中分别随机挑选1000个作为训练集,200个作为验证集。数据预处理、模型训练和pb模型保存代码见:新闻分类模型训练github地址。(非重点,不过多介绍了,github上有详细的使用说明,有问题可留言。)

为了使web服务部署变得简洁,因此本人构造一个方法类,方便加载pb模型,对传入文本进行数据预处理以及进行模型预测。

模型初始化代码如下:

import bert_tokenization
import tensorflow as tf
from tensorflow.python.platform import gfile
import numpy as np
import os

class ClassificationModel(object):
    def __init__(self):
        self.tokenizer = None
        self.sess = None
        self.is_train = None
        self.input_ids = None
        self.input_mask = None
        self.segment_ids = None
        self.predictions = None
        self.max_seq_length = None
        self.label_dict = ['体育', '娱乐', '家居', '彩票', '房产', '教育', '时尚', '时政', '星座', '游戏', '社会', '科技', '股票', '财经']

其中,tokenizer 为分词器;sess为TF的session模块;is_train、input_ids、input_mask和segment_ids分别是pb模型的输入;predictions为pb模型的输出;max_seq_length为模型的最大输入长度;label_dict为新闻分类标签。
加载pb模型代码如下:

def load_model(self, gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length):
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
    self.tokenizer = bert_tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
    sess_config = tf.ConfigProto(gpu_options=gpu_options)
    self.sess = tf.Session(config=sess_config)
    with gfile.FastGFile(model_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        self.sess.graph.as_default()
        tf.import_graph_def(graph_def, name="")

    self.sess.run(tf.global_variables_initializer())
    self.is_train = self.sess.graph.get_tensor_by_name("input/is_train:0")
    self.input_ids = self.sess.graph.get_tensor_by_name("input/input_ids:0")
    self.input_mask = self.sess.graph.get_tensor_by_name("input/input_mask:0")
    self.segment_ids = self.sess.graph.get_tensor_by_name("input/segment_ids:0")
    self.predictions = self.sess.graph.get_tensor_by_name("output_layer/predictions:0")
    self.max_seq_length = max_seq_length

其中,gpu_id为使用GPU的序号;vocab_file为BERT模型所使用的字典路径;gpu_memory_fraction为使用GPU时所占用的比例;model_path为pb模型的路径;max_seq_length为BERT模型的最大长度。

将传入文本转化成模型所需格式代码如下:

def convert_fearture(self, text):
    max_seq_length = self.max_seq_length
    max_length_context = max_seq_length - 2

    content_token = self.tokenizer.tokenize(text)
    if len(content_token) > max_length_context:
        content_token = content_token[:max_length_context]

    tokens = []
    segment_ids = []
    tokens.append("[CLS]")
    segment_ids.append(0)
    for token in content_token:
        tokens.append(token)
        segment_ids.append(0)
    tokens.append("[SEP]")
    segment_ids.append(0)

    input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
    input_mask = [1] * len(input_ids)
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    input_ids = np.array(input_ids)
    input_mask = np.array(input_mask)
    segment_ids = np.array(segment_ids)
    return input_ids, input_mask, segment_ids

预测代码如下:

def predict(self, text):
    input_ids_temp, input_mask_temp, segment_ids_temp = self.convert_fearture(text)
    feed = {self.is_train: False,
            self.input_ids: input_ids_temp.reshape(1, self.max_seq_length),
            self.input_mask: input_mask_temp.reshape(1, self.max_seq_length),
            self.segment_ids: segment_ids_temp.reshape(1, self.max_seq_length)}
    [label] = self.sess.run([self.predictions], feed)
    label_name = self.label_dict[label[0]]
    return label[0], label_name

其中,输入是一个新闻文本,输出为类别序号以及对应的标签名称。详细完整代码见github:

ClassificationModel.py文件。


(划重点)上面介绍的都是如何方便简洁地加载模型,下面开始使用web服务挂起模型。通俗地讲,其实本人就是通过flask框架,搭建了一个web服务,来获取外部的输入;并且使用挂载的模型进行预测;最后将预测结果通过web服务传出。

from gevent import monkey
monkey.patch_all()
from flask import Flask, request
from gevent import wsgi
import json
from ClassificationModel import ClassificationModel


def start_sever(http_id, port, gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length):
    model = ClassificationModel()
    model.load_model(gpu_id, vocab_file, gpu_memory_fraction, model_path, max_seq_length)
    print("load model ending!")
    app = Flask(__name__)

    @app.route('/')
    def index():
        return "This is News Classification Model Server"

    @app.route('/news-classification', methods=['Get', 'POST'])
    def response_request():
        if request.method == 'POST':
            text = request.form.get('text')
        else:
            text = request.args.get('text')
        label, label_name = model.predict(text)
        d = {"label": str(label), "label_name": label_name}
        print(d)
        return json.dumps(d, ensure_ascii=False)

    server = wsgi.WSGIServer((str(http_id), port), app)
    server.serve_forever()

其中,http_id为web服务的地址;port为端口号;gpu_id、vocab_file、gpu_memory_fraction、model_path和max_seq_length为上面介绍的加载模型所需要的参数,详细见上文。

index函数用于检验web服务是否畅通。如图1所示。
Web服务部署深度学习模型_第1张图片
response_request函数为响应函数。定义了两种请求数据的方式,get和post。当使用get方法获取web输入时,获取命令为request.args.get(‘text’);当使用post方法获在这里插入代码片取web输入时,获取命令为request.form.get(‘text’)。

当web服务起起来之后,就可以调用啦!!!

浏览器调用如图2所示。
在这里插入图片描述
Code调用如下:

import requests

def http_test(text):
    url = 'http://127.0.0.1:5555/news-classification'
    raw_data = {'text': text}
    res = requests.post(url, raw_data)
    result = res.json()
    return result

if __name__ == "__main__":
    text = "姚明在NBA打球,很强。"
    result = http_test(text)
    print(result["label_name"])

以上就是通过web服务部署深度学习模型的全部内容,喜欢的同学还请多多点赞~~~~~


重磅!忆臻自然语言处理-学术微信交流群已成立
我们为大家整理了李航老师最新书籍的ppt课件

在这里插入图片描述


添加小助手领取,还可以进入官方交流群!

注意:请大家添加时修改备注为 [学校/公司 + 姓名 + 方向]

例如 —— 哈工大+张三+对话系统。

号主,微商请自觉绕道。谢谢!

Alt

Alt

推荐阅读:
PyTorch Cookbook(常用代码段整理合集)
通俗易懂!使用Excel和TF实现Transformer!
深度学习中的多任务学习(Multi-task-learning)——keras实现

在这里插入图片描述

你可能感兴趣的:(自然语言处理,深度学习,计算机技术)