PyTorch+Flask+Gunicorn 部署深度模型服务

目录

  • 简介
  • 项目实现
    • 模型实现
    • 部署Flask服务
    • Gunicorn加速
    • 性能测试
      • Python 多线程+requests
      • ab工具
    • (可选)HTML网页端
  • 总结
  • 参考文献

简介

初入职场,对于训练完成的模型,对如何被应用到实际的生产环境中产生了疑问,如果要每次都手动向服务器上传一批离线的数据,test一下得到结果,再手动把结果反馈出去就太麻烦了,那么有什么办法可以使模型能让业务端方便地调用呢?

不考虑模型压缩这些问题,要将一个训练完的深度学习模型供生产环境使用,最简单的方法就是写成一个Web服务放在服务器后台,客户端通过HTTP发送请求给服务端,并将数据上传给服务端,服务端对数据进行处理和计算,同样通过HTTP将结果返回给客户端。这个过程中有两个地方是可以深入优化:

  1. 从服务层面:需要赋予Web服务并发处理的能力,以应对多个客户端同时发送请求的场景。当然这是开发人员擅长的东西,暂时没有进行太深入的研究。
  2. 从模型层面:需要优化模型的计算速度,涉及模型压缩和加速技术,这是算法工程师需要解决的问题,但这是一个很大的命题了,将在其他文章中介绍。

文章中不会涉及太复杂的内容,仅仅是构建了一个可顺利运行的简单服务,算是新手入坑之作。对于新手来说,PyTorch应该不陌生,但需要Flask、Gunicorn、HTML以及HTTP的一些基本知识,我也只是到w3school学习了点基础,但是对于这个简单的项目也够用了。

对于Flask,你需要了解:

  1. 如何将客户端的url和服务端的函数方法绑定在一起
  2. 如何接收和传送文件
  3. 怎么传送HTML网页或从HTML网页中提取表单(可选)

对于HTTP:

  1. 需要了解HTTP基本的运作模式
  2. 需要了解HTTP的POST和GET方法,以及返回码的意义
  3. 需要了解HTTP常见的头部信息,最关键是content-type

对于Gunicorn,你只需要了解怎么用它启动Flask,以及它的基本配置方法

还有一些工具使,本文章涉及以下知识:

  1. 终端上发起HTTP请求的工具curl
  2. 用python发起HTTP请求的库requests
  3. 压力测试工具ab
  4. HTML中关于HTTP的设置

项目实现

架构

classifier\
	|————base_classifier.py
	|————resnet50.py
	|————__init__.py
	|————class_name.txt
template\
	|————upload.html
material\
	|————dog.png
classifier_server.py
gunicorn_config.py
launcher.sh

模型实现

简单起见,本项目选定了ResNet50分类模型作为演示,我们需要实现了一个模型类,对给定输入的图片能够进行预处理,然后送入模型中进行计算得到预测类标:

# 文件classifier/base_classifier.py 定义一个模型类的基类BaseClassifier
import  torchvision.transforms as T
class BaseClassifier():

    def __init__(self):
        pass
    
    def predict(self, x):
        pass

    def define_model(self):
        pass
        
    def preprocess(self, x):
        
        tranforms = T.Compose([
            T.Resize((224,224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
        ])

        return tranforms(x).unsqueeze(0)
## 定义一个ResnetClassifier,继承自ResnetClassifier,实现了具体方法
from .base_classifier import BaseClassifier
import numpy as np
from torchvision.models import resnet50
from PIL import Image
import torch as t

with open('classifier/class_name.txt','r') as f:
    raw_name = f.read().splitlines()
id_to_class = {
     }
for row in raw_name:
    ids, class_name = row.split(':')
    id_to_class[int(ids)] = class_name
# print(id_to_class)

class ResnetClassifier(BaseClassifier):

    def __init__(self):
        super().__init__()
        self.model = self.define_model()

    def predict(self, x):
        x_tensor = self.preprocess(x)
        print(x_tensor.shape)
        with t.no_grad():
            output = self.model(x_tensor)
            predict_class = np.argmax(output.numpy())
        return id_to_class[predict_class]

		# 定义模型结构,这里直接采用了torchvision.model提供的预训练模型
    def define_model(self):
        model = resnet50(pretrained=True)
        model.eval()
        return model
        
# 模块测试
if __name__ == "__main__":
    classfier = ResnetClassifier()
    
    img = Image.open('material/dog.png').convert('RGB')
    img.show()

    print(classfier.predict(img))

如果有其他的模型需要部署,继承基类BaseClassifier,然后实现其中的成员函数,也许有的模型更加复杂,那么可以在子类上添加额外的预处理或后处理函数。

部署Flask服务

接下来就是最关键的一步,实现了一个predict函数,将让一个url:’/predict’绑定到到该函数,这样通过requests或者curl发送POST请求给http://127.0.0.1:port/predict,就可以得到返回的预测结果。

首先,你需要明白什么叫作将url: ‘/pridict’绑定到predict函数。举个简单的例子,在服务端的程序中,我们新建一个Flask实例,并且实现了一个简单的函数yyy,再在函数上方用 @app.route(’/xxx’,…)修饰该函数,这样客户端在访问http://127.0.0.1:port/xxx时,服务端就会调用yyy函数,如果yyy函数有return一些变量,客户端也能接收到该变量。

还有一点需要注意的是,xxx和yyy你想怎么命名都行,并不需要同名。@app.route还可以指定客户端请求的方式,可以指定只用GET或者POST或者都可以用。

from flask import Flask

app = Flask(__name__)
@app.route('/xxx', methods = ['GET','POST'])
def yyy():
	print("hello")

其次,明确这个predict函数需要做什么?

  1. 确认请求方法是POST,因为需要客户端上传图片,必须用POST请求
  2. 取出表单中的文件,用io库中BytesIO类和PIL.Image配合将字节流转化为图片
  3. 将图片送入到classifier中分类得到预测结果
  4. 返回预测结果给客户端
# 文件 classifier_server.py
from classifier.resnet50 import ResnetClassifier
from flask import Flask, render_template, request
from werkzeug.utils import secure_filename
import os
import io
from PIL import Image
import logging

app = Flask(__name__)
classifier = ResnetClassifier()

@app.route('/predict', methods = ['POST'])
def predict():
    if request.method == 'POST':
        f = request.files['file']
        f = f.read()
        byte_stream = io.BytesIO(f)
        img = Image.open(byte_stream).convert('RGB')
        result = classifier.predict(img)
        # app.logger.debug('This is an Debug Info')
        return str(result)
		elif request.method == 'GET':
				return "unsupport method GET"

if __name__ == '__main__':
    app.run(debug=True)

通过下面指令启动服务:

python3 classifier_server.py

然后可以通过requests库的POST方法,或者curl向服务器发送请求,上传一张图片,询问该图片的分类,具体的实现方式如下:

  1. 第一种方法通过requests库
import requests

files = {
     'file': open('material/dog.png','rb')}

res = requests.post('http://127.0.0.1:5000/predict', files=files)

print(res.text)

# 打印:'Tibetan mastiff'
# 预测是藏獒,但这只汪汪是藏獒吗?不去在意这些细节了.
  1. 第二种方法通过curl,在命令行中输入:
curl -F 'file=@material/dog.png' http://127.0.0.1:5000/predict

# 打印:'Tibetan mastiff'
# 和上面一致

Gunicorn加速

Flask是一个非常简单的web框架,它具备一个web服务的基本元素,但是它无法用在生产环境,因为它多线程和多进程的能力堪忧,一般需要配合Gunicorn一起用,Gunicorn赋予了Flask处理多线程和多进程的能力,接管了任务调度等工作,它的使用方法也非常简单:

首先需要在项目目录下新建一个gunicorn的配置文件gunicorn_config.py

# 文件 gunicorn_config.py
loglevel = 'debug' #日志级别 debug info warning error critical
bind = "127.0.0.1:5000" #绑定地址和端口 # utils.get_host_ip2()+':8000'
daemon = False # 是否以守护进程启动
# workers = multiprocessing.cpu_count() * 2 + 1 #启动进程数
workers = 4 #10
worker_class = 'gthread' #工作模式 切记不能使用 gevent ,会拦截内部flask发出的请求
threads = 4 #每个工作者线程数
worker_connections = 2000 # 最大并发量
pidfile = "./log/gunicorn.pid" # pid 文件
accesslog = "./log/access.log" #访问日志目录
errorlog = "./log/debug.log" #出错日志
graceful_timeout = 300
timeout = 300 #reload worker after slicent 3 secs
# preload_app=True #是否预加载app,加快启动速度
# reload=True # 代码更新自动重启

然后新建一个launch.sh文件,用gunicorn启动Flask:

# launch.sh文件
if [ "$1" = "daemon" ];then
    gunicorn -c gunicorn_config.py --daemon classifier_server:app #守护进程启动
else
	gunicorn -c gunicorn_config.py classifier_server:app #非守护进程启动
fi

最后在命令行运行launch.sh即可以4进程4线程的方式运行Flask服务,我们同样可以用上一小结的requests库或者curl发送请求,测试服务是否正常,代码和指令的写法一样,这里就不再赘述了。

性能测试

Python 多线程+requests

最简单的测试方法就是用python的多线程结合requests库的post函数,向服务器同时发起多个请求,你可以不断提高NUM_REQUESTS,看看服务器会不会出现什么问题:

# USAGE
# python stress_test.py

from threading import Thread
import requests
import time

# initialize the API endpoint URL along with the input
# image path
API_URL = "http://172.19.8.88:5000/predict/"
IMAGE_PATH = "material/dog.png"

# initialize the number of requests for the stress test along with
# the sleep amount between requests
NUM_REQUESTS = 500
SLEEP_COUNT = 0.05

def call_predict_endpoint(n):
    # load the input image and construct the payload for the request
    payload = {
     'file': open('material/dog.png','rb')}

    # submit the request
    r = requests.post(API_URL, files=payload)

    # ensure the request was sucessful
    if r.text=='Tibetan mastiff':
        print("[INFO] thread {} OK".format(n))

    # otherwise, the request failed
    else:
        print("[INFO] thread {} FAILED".format(n))

# loop over the number of threads
for i in range(0, NUM_REQUESTS):
    # start a new thread to call the API
    t = Thread(target=call_predict_endpoint, args=(i,))
    t.daemon = True
    t.start()
    time.sleep(SLEEP_COUNT)

# insert a long sleep so we can wait until the server is finished
# processing the images
time.sleep(300)

ab工具

另一种相对更专业的方法是用ab,ab一款Web服务器压力测试小工具,但ab不能原生地支持post图片,需要自己为图片构造HTTP头部,具体地方法是用vim或notepad++打开图片文件,然后在文件内加上下面

--1234567890
Content-Disposition: form-data; name="file"; filename="dog.png" 
Content-type: image/png

[文件原始的内容]
--1234567890

关于这些内容的解释来自(https://blog.csdn.net/weixin_39494902/article/details/109538560):

1234567890:这是边界标识,你可以替换成任何字符串,只要和后面边界信息保持一致就好了;
Content-Disposition:这个不用做修改;
name:这里影响的是服务接收后的文件解析的key,如下图;
filename:上传后的一些文件信息,实际用处不大,随便起一个名字;
Content-type:上传的文件内容的类型,jpg格式图片所以使用image/jpeg;文章最后我会附上所有对应列表;

有两点需要注意:

  1. 直接用vim或notepad++打开图片文件,不要尝试先修改文件后缀成.txt后再打开,亲测这样会导致之后PIL认不得图片,另外,修改后的文件后缀可以任意选择。
  2. 一定要严格按照上面的格式!多一行少一行都不行!开头添加的内容要与文件原始内容空开一行,结尾添加的内容则不要和原始内容空开一行。

在这之后,通过下面指令可以对服务器发起并发测试,发起1000个请求,并发数目为10:

ab -n 1000 -c 10 -p material/dog.txt -T "multipart/form-data; boundary=1234567890" http://localhost:5000/predict/

当然这里客户端和服务端都是同一个机器,实际情况可能并不是这样,这里只是做了个演示。

然后我们比较一下直接启动Flask和用Gunicorn启动Flask的差距:

Flask

Server Software:        Werkzeug/1.0.1
Server Hostname:        localhost
Server Port:            5000

Document Path:          /predict/
Document Length:        15 bytes

Concurrency Level:      10
Time taken for tests:   79.343 seconds    # 测试时间79.43秒
Complete requests:      1000
Failed requests:        0
Total transferred:      168000 bytes
Total body sent:        1397200000
HTML transferred:       15000 bytes
Requests per second:    12.60 [#/sec] (mean)   # 吞吐率,每秒处理12.6个请求
Time per request:       793.426 [ms] (mean)    # 用户平均等待时间为 793.426ms
Time per request:       79.343 [ms] (mean, across all concurrent requests) # 服务器平均等待时间为79.343ms
Transfer rate:          2.07 [Kbytes/sec] received
                        17196.99 kb/s sent
                        17199.05 kb/s total

Connection Times (ms)
              min  mean[+/-sd] median   max
Connect:        0    1   4.3      0      64
Processing:   380  790 218.6    758    1677
Waiting:      370  779 218.7    746    1667
Total:        381  791 219.7    760    1678

Percentage of the requests served within a certain time (ms)
  50%    760    # 50%的请求在760ms内完成
  66%    832
  75%    889
  80%    919
  90%   1047    # 90%的请求在1047ms内完成
  95%   1244
  98%   1464
  99%   1519    # 99%的请求在1519ms内完成
 100%   1678 (longest request)

Gunicorn+Flask:


Server Software:        gunicorn/20.0.4
Server Hostname:        localhost
Server Port:            5000

Document Path:          /predict/
Document Length:        15 bytes

Concurrency Level:      10
Time taken for tests:   60.444 seconds   # 测试完成时间60秒
Complete requests:      1000
Failed requests:        0
Total transferred:      175000 bytes
Total body sent:        1397200000
HTML transferred:       15000 bytes
Requests per second:    16.54 [#/sec] (mean)   #吞吐率,平均每秒处理16.54个请求
Time per request:       604.442 [ms] (mean)     # 用户平均等待时间604.442ms
Time per request:       60.444 [ms] (mean, across all concurrent requests)  # 服务器平均请求等待时间60.44ms
Transfer rate:          2.83 [Kbytes/sec] received
                        22573.76 kb/s sent
                        22576.59 kb/s total

Connection Times (ms)
              min  mean[+/-sd] median   max
Connect:        0    2   5.4      0      59
Processing:   225  601 183.3    576    1336
Waiting:      214  589 182.7    563    1324
Total:        226  603 183.1    577    1336

Percentage of the requests served within a certain time (ms)
  50%    577       # 50%的请求在760ms内完成
  66%    630
  75%    675
  80%    709
  90%    819       # 90%的请求在819ms内完成
  95%   1030
  98%   1132
  99%   1168       # 90%的请求在1168ms内完成
 100%   1336 (longest request)

直接启动Flask时,吞吐率为每秒处理12.6个线程,低于Gunicorn的16.54,用户平均等待时间为793.426ms,长于Gunicorn的604.442ms,服务器平均等待时间也是Gunicorn更佳。总体来说,用Gunicorn启动Flask是一个不错的选择,反正配置也不太复杂,能用则用。

(可选)HTML网页端

既然都上Web服务了,那么再做个网页,这样用户通过在浏览器中访问http://127.0.0.1:5000/upload,然后上传图片得到预测结果了。不过,这涉及到HTML网页的编写了,这里给一个很简单的例子,掌握一些HTML的基本语句就能写了,只是界面非常丑陋,因为没有设计相应的CSS样式,不过这些不是重点。

新建了一个upload.html文件,核心部分就是中间的表单form,表单内有两个输入项,一个是接受.jpg或者.png文件,另一个是提交按钮submit,点击submit,就会以POST方法,向action所指定的url:"http://127.0.0.1:5000/predict"发送图片内容。

<html>
<head>
  <title>File Uploadtitle>
head>
<body>
    
    <form action="http://127.0.0.1:5000/predict" method="POST" enctype="multipart/form-data">
        <input type="file" name="file" accept=".jpg,.png" />
        <input type="submit" />
    form>
body>
html>

然后在classifier_server.py中将url:’/upload’和函数upload_file绑定在一起,在upload_file函数内,启用upload.html网页

@app.route('/upload')
def upload_file():
   return render_template('upload.html')

这样我们访问"http:127.0.0.1:5000/upload"时就会打开一个upload网页,要求我们上传图片:

PyTorch+Flask+Gunicorn 部署深度模型服务_第1张图片

我们上传完图片后点击提交后,就能获得预测结果:

PyTorch+Flask+Gunicorn 部署深度模型服务_第2张图片

总结

以上就是将模型部署成一个Web服务的全过程,但其实还有许多可以完善的细节,比如:

  1. 想要部署的模型也许比较复杂,但并不涉及web框架,你可以自己实现模型预测的逻辑,留出一个预测接口即可供Web框架调用即可,即BaseClassifier中的predict方法。
  2. 需要客户端上传的不只是数据,可能还有一些配置,比如说我想设定预测的阈值等,那么还需要客户端用Get方法或者Post方法传送这些配置参数,可以是用Post传送表单的方法,和图片一起传送过来,然后服务端通过Flask的request类取出数据。
  3. 增加nginx或Apache组件以及redis组件,构建一个完整的系统(这里先埋一个坑,日后实现)

参考文献

  1. https://blog.csdn.net/weixin_39494902/article/details/109538560
  2. https://www.jiqizhixin.com/articles/2018-02-12

你可能感兴趣的:(深度学习,flask,深度学习)