PyTorch 模型使用 TorchServe 部署比较方便、简单，也便于管理。但是由于内网服务器系统的原因，无法使用 TorchServe，所以选择用 Flask 框架编写 Web API 的方式来调用模型。
1. 首先将模型导出为 ONNX 格式，然后使用 ONNX Runtime 加载并调用。服务端代码如下：
import json
import re
import logging
import cv2
import torchvision.transforms as T
import numpy as np
import torch
import os
import onnxruntime
from flask import Flask,jsonify,abort,request
app = Flask(__name__)
# NOTE(review): plain Flask does not read "RESTIFUL_JSON" (Flask-RESTful uses
# "RESTFUL_JSON"), so this line by itself does not change jsonify() output.
# It is kept only in case other code reads the key.
app.config.update(RESTIFUL_JSON=dict(ensure_ascii=False))
# Make jsonify() write non-ASCII label text verbatim instead of \uXXXX escapes.
if hasattr(getattr(app, "json", None), "ensure_ascii"):
    app.json.ensure_ascii = False  # Flask >= 2.2 JSON provider
else:
    app.config["JSON_AS_ASCII"] = False  # older Flask versions

# Directory that /classify reads images from; override with the IMAGE_DIR env var.
imageDir = os.environ.get("IMAGE_DIR", "E:\\data\\test\\")
# ONNX Runtime session for the exported model (path is relative to the working dir).
ortSession = onnxruntime.InferenceSession("./resnet34.onnx")
# classifyImage applies this to output[0][0], presumably the 1-D class scores of
# the single image; dim=0 is what the deprecated implicit-dim form chose for 1-D
# input, without the warning.
softmax = torch.nn.Softmax(dim=0)
# filemode='w' truncates record.log every time the process starts.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                    datefmt='%a, %d %b %Y %H:%M:%S',
                    filename='./record.log',
                    filemode='w')
# Maps the predicted class index (as a string key) to its label.
with open("./index.json", "r", encoding="utf-8") as f:
    dictIdImage = json.load(f)
@app.route("/")
def index():
    """Root endpoint: a plain banner naming this service."""
    service_banner = "image classify service resnet34"
    return service_banner
@app.route("/help")
def help():  # NOTE: shadows builtin help() in this module; name kept for the "help" endpoint
    """Return a usage example for the /classify endpoint.

    Served as text/plain: Flask's default text/html would make browsers
    collapse the sample code's line breaks onto a single line.
    """
    usage = '''
python接口调用
import requests
data = {"imagename":"test.jpg"}
r = requests.post("http://xx.xx.xx.xx:5001/classify", data=data).json()
'''
    return usage, 200, {"Content-Type": "text/plain; charset=utf-8"}
def _resolve_image_path(image_name):
    """Return the absolute path of *image_name* inside imageDir, or None.

    None means the joined path would leave imageDir. That covers an absolute
    name (os.path.join discards imageDir for those), ``..`` components that
    climb out of it, or, on Windows, a path on another drive.
    """
    base = os.path.abspath(imageDir)
    candidate = os.path.abspath(os.path.join(base, image_name))
    try:
        common = os.path.commonpath(
            [os.path.normcase(base), os.path.normcase(candidate)])
    except ValueError:  # commonpath raises when the paths are on different drives
        return None
    if common != os.path.normcase(base):
        return None
    return candidate


@app.route("/classify", methods=["POST"])
def classifyImage():
    """Classify one image from imageDir named by the "imagename" form field.

    JSON responses:
        200 {"msg": "ok", "result": {"label": ..., "prob": "<rounded str>"}}
        400 {"msg": "illegal image path"}  name is not .jpg or escapes imageDir
        400 {"msg": "image not found"}     the file could not be read
    """
    imageName = request.form["imagename"]
    imagePath = os.path.join(imageDir, imageName)
    remoteIp = request.remote_addr
    # Log the requested path before validation so rejected attempts are recorded too.
    logging.info("ip {} post image {}".format(remoteIp, imagePath))
    # The name is client-controlled: require a .jpg suffix and refuse any path
    # that resolves outside imageDir.
    safePath = None
    if imageName.lower().endswith(".jpg"):
        safePath = _resolve_image_path(imageName)
    if safePath is None:
        logging.warning("illegal image path")
        return jsonify({"msg": "illegal image path"}), 400
    flag, img = image_process(safePath)
    if not flag:
        logging.warning("image not found")
        return jsonify({"msg": "image not found"}), 400
    output = ortSession.run(None, {"input": img.numpy()})
    # output[0] is the first model output; [0] then picks the single image of the
    # batch (presumably a (batch, classes) array — confirm against the model).
    pred = softmax(torch.from_numpy(output[0][0])).numpy()
    classId = int(np.argmax(pred))
    prob = str(np.round(np.max(pred), 3))
    label = dictIdImage[str(classId)]
    logging.info("imagename: {} predict:{} probility: {}".format(imageName, label, prob))
    return jsonify({"msg": "ok", "result": {"label": label, "prob": prob}}), 200
# Built once at import time rather than per request. ToPILImage receives the
# (channels, height, width) uint8 tensor that resize_image returns.
_IMAGE_TRANSFORM = T.Compose([
    T.ToPILImage(),
    T.ToTensor(),  # uint8 [0, 255] -> float [0, 1]
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # -> [-1, 1]
])


def image_process(image_path):
    """Read an image file and turn it into a model input tensor.

    Args:
        image_path: path passed to cv2.imread.

    Returns:
        ``(True, tensor)`` on success. The tensor is the image letterboxed to
        [640, 450] by resize_image, normalized to [-1, 1], with a leading batch
        dimension of 1. Returns ``(False, "")`` when the image cannot be read.
    """
    try:
        image = cv2.imread(image_path)
    except Exception:
        # cv2 may raise (rather than return None) for unusable path arguments;
        # treat that like a missing file, but keep the traceback in the log.
        logging.exception("cv2.imread failed for %s", image_path)
        return False, ""
    if image is None:  # cv2.imread returns None for missing/unreadable files
        return False, ""
    # Log only the shape; dumping the pixel array would flood the output.
    logging.debug("read image %s, shape %s", image_path, image.shape)
    image = resize_image(image, [640, 450])
    image = _IMAGE_TRANSFORM(image)
    return True, image.unsqueeze(0)
def resize_image(img0, shape, color=(127.5, 127.5, 127.5)):
    """Letterbox *img0* to exactly *shape* pixels while keeping its aspect ratio.

    The image is scaled by the largest factor that fits inside the target. It
    is then padded on both sides with *color* up to the target size.

    Args:
        img0: image array indexed (height, width, channels), e.g. from cv2.imread.
        shape: target size as [width, height] (the order cv2.resize uses).
        color: per-channel fill value for the padding.

    Returns:
        torch.Tensor laid out (channels, height, width) with img0's dtype.
    """
    height0, width0 = img0.shape[:2]
    target_w, target_h = shape[0], shape[1]
    ratio = min(float(target_w) / width0, float(target_h) / height0)
    # Clamp to 1 px so very thin images cannot round to a zero-sized side,
    # which cv2.resize rejects.
    new_w = max(1, round(width0 * ratio))
    new_h = max(1, round(height0 * ratio))
    img = cv2.resize(img0, (new_w, new_h), interpolation=cv2.INTER_AREA)
    dw, dh = (target_w - new_w) / 2, (target_h - new_h) / 2
    # The +/-0.1 nudge splits an odd padding total as floor/ceil, so the two
    # sides always add up to the full padding.
    top, bottom = round(dh - 0.1), round(dh + 0.1)
    left, right = round(dw - 0.1), round(dw + 0.1)
    # The fill colour must be passed as value=: the 7th positional parameter
    # of copyMakeBorder is dst, so a positional colour would be ignored.
    # NOTE(review): if the model was trained on images padded by the positional
    # (colour-ignoring) call, confirm accuracy with the fill now applied.
    img = cv2.copyMakeBorder(img, top, bottom, left, right,
                             cv2.BORDER_CONSTANT, value=color)
    # HWC -> CHW; made contiguous after the transpose, which returns a strided view.
    img = np.ascontiguousarray(img.transpose((2, 0, 1)))
    return torch.from_numpy(img)
if __name__ == "__main__":
    # debug=True enables Werkzeug's interactive debugger, which lets anyone who
    # can reach the server run arbitrary Python. It also starts the reloader,
    # which imports this module twice and so loads the model twice. Enable
    # debug only explicitly via FLASK_DEBUG=1.
    app.run(host="xx.xx.xx.xx", port=5001,
            debug=os.environ.get("FLASK_DEBUG") == "1")
2. 访问（客户端调用示例）：
import requests

data = {"imagename": "test.jpg"}
# Without a timeout, requests waits indefinitely if the server never answers.
r = requests.post("http://xx.xx.xx.xx:5001/classify", data=data, timeout=30).json()
print(r)