本章节中,我们将使用Flask 部署一个Pytorch模型,并为模型预测提供一个REST API 接口。下面,我们部署一个预训练好的模型DenseNet 121,该模型用于检测图片
import io
import json
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
app = Flask(__name__)
# 加载lable和类别名称关系
imagenet_class_index = json.load(open('imagenet_class_index.json'))
def transform_image(image_bytes):
"""
DenseNet model requires the image to be of 3 channel RGB image of size 224 x 224.
下面对原始图片的预处理
1. transforms.Resize 改变原始图片的大小
2.transforms.CenterCrop 生成一个CenterCrop类的对象,用来将图片从中心裁剪成224*224
将给定的PIL.Image进行中心切割,得到给定的size,size可以是tuple,(target_height, target_width)。size也可以是一个Integer,
在这种情况下,切出来的图片的形状是正方形。
3. transforms.ToTensor 转为tensor,在GPU上运行
4. transforms.Normalize 参数处理功能描述:图片标准化处理
We will also normalise the image tensor with the required mean and standard deviation values.
You can read more about it here.
"""
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
# 返回 PIL.Image.Image 对象
image = Image.open(io.BytesIO(image_bytes))
px = my_transforms(image).unsqueeze(0) # 通过unsqueeze(0) 后:torch.Size([3, 224, 224])->torch.Size([1, 3, 224, 224])
return px
接下来,我们来提取一张图片的特征,观察下返回的tensor 的结果
with open("cat_pic.jpeg", 'rb') as f:
image_bytes = f.read()
tensor = transform_image(image_bytes=image_bytes)
print(tensor.shape) # torch.Size([1, 3, 224, 224]) 表示: 一张图片+3个渠道+长度224+宽度224 的数组
print(tensor)
torch.Size([1, 3, 224, 224])
tensor([[[[-0.6109, -0.5424, -0.4568, ..., -1.6727, -1.6898, -1.7240],
[-0.5596, -0.4397, -0.3883, ..., -1.7240, -1.7583, -1.7754],
[-0.5253, -0.3883, -0.3369, ..., -1.7583, -1.7925, -1.7925],
...,
[ 0.9132, 0.7591, 0.6221, ..., 1.8722, 1.9235, 1.9749],
[ 0.8104, 0.7077, 0.3481, ..., 1.8550, 1.8550, 1.8722],
[ 0.3481, -0.0116, -0.3883, ..., 1.8722, 1.8550, 1.8379]],
[[-0.4951, -0.4426, -0.3725, ..., -1.2654, -1.3004, -1.3004],
[-0.4076, -0.3375, -0.2850, ..., -1.3354, -1.3880, -1.4055],
[-0.3725, -0.2850, -0.2500, ..., -1.3880, -1.4580, -1.4755],
...,
[ 0.2227, 0.0126, -0.2150, ..., 1.7283, 1.7983, 1.8859],
[-0.0049, -0.1275, -0.4076, ..., 1.7108, 1.6933, 1.6933],
[-0.2500, -0.5301, -0.7752, ..., 1.7108, 1.6933, 1.6758]],
[[-1.0027, -0.9504, -0.8981, ..., -1.3861, -1.4036, -1.4210],
[-0.9156, -0.8807, -0.8284, ..., -1.4384, -1.4384, -1.4559],
[-0.9330, -0.8807, -0.8110, ..., -1.4733, -1.4907, -1.5081],
...,
[-0.3753, -0.5670, -0.7587, ..., 1.5942, 1.7163, 1.9080],
[-0.6018, -0.7238, -0.8981, ..., 1.5420, 1.5071, 1.5245],
[-0.7413, -0.8633, -1.0201, ..., 1.5420, 1.4897, 1.4722]]]])
from torchvision import models
# Make sure to pass `pretrained` as `True` to use the pretrained weights:
model = models.densenet121(pretrained=True)
# Since we are using our model only for inference, switch to `eval` mode:
model.eval()
# 加载lable和类别名称关系
imagenet_class_index = json.load(open('imagenet_class_index.json'))
def get_prediction(image_bytes):
# 数据预处理-图片特征提取
tensor = transform_image(image_bytes=image_bytes)
# 模型预测
outputs = model.forward(tensor)
# 输出结果可能性最大的一个数值
_, y_hat = outputs.max(1)
# tensor 转为一个数值类型数据
predicted_idx = str(y_hat.item())
# 获取名称和index
return imagenet_class_index[predicted_idx]
with open("cat_pic.jpeg", 'rb') as f:
image_bytes = f.read()
print(get_prediction(image_bytes=image_bytes))
['n02127052', 'lynx']
import io
import json
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
app = Flask(__name__)
# 加载lable和类别名称关系
imagenet_class_index = json.load(open('imagenet_class_index.json'))
#服务器启动加载,
model = models.densenet121(pretrained=True)
#dropout and batch normalization layers to evaluation mode
model.eval()
def transform_image(image_bytes):
"""
DenseNet model requires the image to be of 3 channel RGB image of size 224 x 224.
下面对原始图片的预处理
1. transforms.Resize 改变原始图片的大小
2.transforms.CenterCrop 生成一个CenterCrop类的对象,用来将图片从中心裁剪成224*224
将给定的PIL.Image进行中心切割,得到给定的size,size可以是tuple,(target_height, target_width)。size也可以是一个Integer,
在这种情况下,切出来的图片的形状是正方形。
3. transforms.ToTensor 转为tensor,在GPU上运行
4. transforms.Normalize 参数处理功能描述:图片标准化处理
We will also normalise the image tensor with the required mean and standard deviation values.
You can read more about it here.
"""
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
# 返回 PIL.Image.Image 对象
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
def get_prediction(image_bytes):
# 数据预处理-图片特征提取
tensor = transform_image(image_bytes=image_bytes)
# 模型预测
outputs = model.forward(tensor)
# 输出结果可能性最大的一个数值
_, y_hat = outputs.max(1)
# tensor 转为一个数值类型数据
predicted_idx = str(y_hat.item())
# 获取名称和index
return imagenet_class_index[predicted_idx]
@app.route('/')
def hello():
return "Hello World!"
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
* Serving Flask app "__main__" (lazy loading)
* Environment: production
WARNING: Do not use the development server in a production environment.
Use a production WSGI server instead.
* Debug mode: off
* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
上述代码:app.py ,可以命令行执行
1. 启动服务
$ FLASK_ENV=development FLASK_APP=app.py flask run
* Serving Flask app "app.py" (lazy loading)
* Environment: development
* Debug mode: on
* Restarting with stat
* Debugger is active!
* Debugger PIN: 276-234-659
* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
2. 在线预测
curl -X POST -F file=@cat_pic.jpeg http://127.0.0.1:5000/predict
{
"class_id": "n02127052",
"class_name": "lynx"
}
[1] 基于Flask 部署Pytorch模型
https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html?highlight=serving
[2] Pytorch 提供基于ImageNet预训练模型
https://pytorch.org/docs/stable/torchvision/models.html
[3] Pytorch 模型保持和加载
https://pytorch.org/tutorials/beginner/saving_loading_models.html