基于Flask Web框架提供Pytorch 模型在线服务

本章节中,我们将使用Flask 部署一个Pytorch模型,并为模型预测提供一个REST API 接口。下面,我们部署一个预训练好的模型DenseNet 121,该模型用于检测图片

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

app = Flask(__name__)

# 加载lable和类别名称关系
imagenet_class_index = json.load(open('imagenet_class_index.json'))

图片特征提取

def transform_image(image_bytes):
    """
        DenseNet model requires the image to be of 3 channel RGB image of size 224 x 224.
        下面对原始图片的预处理

        1. transforms.Resize 改变原始图片的大小

        2.transforms.CenterCrop 生成一个CenterCrop类的对象,用来将图片从中心裁剪成224*224
        将给定的PIL.Image进行中心切割,得到给定的size,size可以是tuple,(target_height, target_width)。size也可以是一个Integer,
        在这种情况下,切出来的图片的形状是正方形。

        3. transforms.ToTensor 转为tensor,在GPU上运行

        4. transforms.Normalize 参数处理功能描述:图片标准化处理
        We will also normalise the image tensor with the required mean and standard deviation values.
        You can read more about it here.
    """
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])

    # 返回 PIL.Image.Image 对象
    image = Image.open(io.BytesIO(image_bytes))
    px = my_transforms(image).unsqueeze(0) # 通过unsqueeze(0) 后:torch.Size([3, 224, 224])->torch.Size([1, 3, 224, 224])

    return px

接下来,我们来提取一张图片的特征,观察下返回的tensor 的结果

with open("cat_pic.jpeg", 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor.shape) # torch.Size([1, 3, 224, 224])  表示: 一张图片+3个渠道+长度224+宽度224 的数组
    print(tensor)
torch.Size([1, 3, 224, 224])
tensor([[[[-0.6109, -0.5424, -0.4568,  ..., -1.6727, -1.6898, -1.7240],
          [-0.5596, -0.4397, -0.3883,  ..., -1.7240, -1.7583, -1.7754],
          [-0.5253, -0.3883, -0.3369,  ..., -1.7583, -1.7925, -1.7925],
          ...,
          [ 0.9132,  0.7591,  0.6221,  ...,  1.8722,  1.9235,  1.9749],
          [ 0.8104,  0.7077,  0.3481,  ...,  1.8550,  1.8550,  1.8722],
          [ 0.3481, -0.0116, -0.3883,  ...,  1.8722,  1.8550,  1.8379]],

         [[-0.4951, -0.4426, -0.3725,  ..., -1.2654, -1.3004, -1.3004],
          [-0.4076, -0.3375, -0.2850,  ..., -1.3354, -1.3880, -1.4055],
          [-0.3725, -0.2850, -0.2500,  ..., -1.3880, -1.4580, -1.4755],
          ...,
          [ 0.2227,  0.0126, -0.2150,  ...,  1.7283,  1.7983,  1.8859],
          [-0.0049, -0.1275, -0.4076,  ...,  1.7108,  1.6933,  1.6933],
          [-0.2500, -0.5301, -0.7752,  ...,  1.7108,  1.6933,  1.6758]],

         [[-1.0027, -0.9504, -0.8981,  ..., -1.3861, -1.4036, -1.4210],
          [-0.9156, -0.8807, -0.8284,  ..., -1.4384, -1.4384, -1.4559],
          [-0.9330, -0.8807, -0.8110,  ..., -1.4733, -1.4907, -1.5081],
          ...,
          [-0.3753, -0.5670, -0.7587,  ...,  1.5942,  1.7163,  1.9080],
          [-0.6018, -0.7238, -0.8981,  ...,  1.5420,  1.5071,  1.5245],
          [-0.7413, -0.8633, -1.0201,  ...,  1.5420,  1.4897,  1.4722]]]])

在线服务预测

from torchvision import models

# Make sure to pass `pretrained` as `True` to use the pretrained weights:
model = models.densenet121(pretrained=True)
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()

# 加载lable和类别名称关系
imagenet_class_index = json.load(open('imagenet_class_index.json'))

def get_prediction(image_bytes):
    # 数据预处理-图片特征提取
    tensor = transform_image(image_bytes=image_bytes)
    # 模型预测
    outputs = model.forward(tensor)
    # 输出结果可能性最大的一个数值
    _, y_hat = outputs.max(1)
    # tensor 转为一个数值类型数据
    predicted_idx = str(y_hat.item())
    # 获取名称和index
    return imagenet_class_index[predicted_idx]
with open("cat_pic.jpeg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))
['n02127052', 'lynx']

基于Flask 提供 API Server

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

app = Flask(__name__)

# 加载lable和类别名称关系
imagenet_class_index = json.load(open('imagenet_class_index.json'))

#服务器启动加载,
model = models.densenet121(pretrained=True)
#dropout and batch normalization layers to evaluation mode
model.eval()


def transform_image(image_bytes):
    """
        DenseNet model requires the image to be of 3 channel RGB image of size 224 x 224.
        下面对原始图片的预处理

        1. transforms.Resize 改变原始图片的大小

        2.transforms.CenterCrop 生成一个CenterCrop类的对象,用来将图片从中心裁剪成224*224
        将给定的PIL.Image进行中心切割,得到给定的size,size可以是tuple,(target_height, target_width)。size也可以是一个Integer,
        在这种情况下,切出来的图片的形状是正方形。

        3. transforms.ToTensor 转为tensor,在GPU上运行

        4. transforms.Normalize 参数处理功能描述:图片标准化处理
        We will also normalise the image tensor with the required mean and standard deviation values.
        You can read more about it here.
    """
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])

    # 返回 PIL.Image.Image 对象
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    # 数据预处理-图片特征提取
    tensor = transform_image(image_bytes=image_bytes)
    # 模型预测
    outputs = model.forward(tensor)
    # 输出结果可能性最大的一个数值
    _, y_hat = outputs.max(1)
    # tensor 转为一个数值类型数据
    predicted_idx = str(y_hat.item())
    # 获取名称和index
    return imagenet_class_index[predicted_idx]


@app.route('/')
def hello():
    return "Hello World!"


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':

    app.run()
 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   WARNING: Do not use the development server in a production environment.
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)

上述代码:app.py ,可以命令行执行


1. 启动服务
$ FLASK_ENV=development FLASK_APP=app.py flask run
 * Serving Flask app "app.py" (lazy loading)
 * Environment: development
 * Debug mode: on
 * Restarting with stat
 * Debugger is active!
 * Debugger PIN: 276-234-659
 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)



2. 在线预测
curl -X POST -F file=@cat_pic.jpeg http://127.0.0.1:5000/predict
{
  "class_id": "n02127052",
  "class_name": "lynx"
}

参考资料

[1] 基于Flask 部署Pytorch模型

https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html?highlight=serving

[2] Pytorch 提供基于ImageNet预训练模型

https://pytorch.org/docs/stable/torchvision/models.html

[3] Pytorch 模型保持和加载

https://pytorch.org/tutorials/beginner/saving_loading_models.html

你可能感兴趣的:(深度学习与图像处理)