本文章参考,仅作为个人笔记使用
https://www.bilibili.com/video/BV1Qv41117SR?spm_id_from=333.999.0.0&vd_source=91ca5a3d30d9b0f2a4679da8345eb0f5
要安装Flask,postman(可选)
pytorch等深度学习必备
1、准备好要部署的模型和训练好的权重
2、准备好部署在前端的web框架
3、主函数中对模型进行实例化加载权重
4、写好对输入图像进行预处理的函数然后输入到模型中
5、写好预测并得到最后结果的函数
6、使用flask进行部署,这部分是模板
使用参考中所使用的MobileNetV2,以及训练flowers数据集得到的权重
使用参考视频提供的简单的前端框架
import os
import io
import json
import torch
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request, render_template
from flask_cors import CORS #解决跨域问题,不一定必须
from model import MobileNetV2
app = Flask(__name__)
CORS(app) # 解决跨域问题
weights_path = "./MobileNetV2(flower).pth" # 模型权重路径
class_json_path = "./class_indices.json" #分类结果对应的json文件
assert os.path.exists(weights_path), "weights path does not exist..."
assert os.path.exists(class_json_path), "class json path does not exist..."
# select device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# create model
model = MobileNetV2(num_classes=5).to(device) #实例化模型
# load model weights
model.load_state_dict(torch.load(weights_path, map_location=device)) #加载权重
model.eval()
# load class info
json_file = open(class_json_path, 'rb') #读入分类结果文件
class_indict = json.load(json_file)
def transform_image(image_bytes): # 读入的是二进制文件
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes)) # 转化二进制文件
if image.mode != "RGB":
raise ValueError("input file does not RGB image...")
return my_transforms(image).unsqueeze(0).to(device)
def get_prediction(image_bytes):
try: # 确保过程没有错误
tensor = transform_image(image_bytes=image_bytes) # 输入图像进行预处理
outputs = torch.softmax(model.forward(tensor).squeeze(), dim=0) # 输出结果softmax
prediction = outputs.detach().cpu().numpy() #把预测结果传回cpu
template = "class:{:<15} probability:{:.3f}" #展现结果的模板,<15表示占据15个空格,保证对齐效果
index_pre = [(class_indict[str(index)], float(p)) for index, p in enumerate(prediction)] # 提取概率结果和索引
# sort probability
index_pre.sort(key=lambda x: x[1], reverse=True) # 根根据概率进行排序
text = [template.format(k, v) for k, v in index_pre]
return_info = {"result": text}
except Exception as e:
return_info = {"result": [str(e)]}
return return_info
@app.route("/predict", methods=["POST"])
@torch.no_grad()
def predict():
image = request.files["file"]
img_bytes = image.read()
info = get_prediction(image_bytes=img_bytes)
return jsonify(info)
@app.route("/", methods=["GET", "POST"])
def root():
return render_template("up.html") #使用准备好的前端模板
浏览器输入localhost:5000或者本机ip地址:5000