PyTorch 模型部署

使用 TorchServe 部署 PyTorch 模型比较方便、简单，也便于管理。但由于内网服务器系统的限制，无法使用 TorchServe，所以选择用 Flask 框架编写 Web API 的方式来调用模型。

1. 首先将模型导出为 ONNX 格式，然后使用 ONNX Runtime 加载并调用。

import json
import re
import logging
import cv2
import torchvision.transforms as T
import numpy as np
import torch
import os
import onnxruntime
from flask import Flask,jsonify,abort,request
app = Flask(__name__)
# Emit non-ASCII labels (e.g. Chinese class names) unescaped in jsonify()
# responses. Plain Flask does not read Flask-RESTful's RESTFUL_JSON setting,
# so the option is set on Flask's own JSON machinery instead.
if hasattr(app, "json"):  # Flask >= 2.2: per-app JSON provider object
    app.json.ensure_ascii = False
else:  # Flask < 2.2 reads this config key
    app.config["JSON_AS_ASCII"] = False

# Directory that requested image names are resolved against.
imageDir = "E:\\data\\test\\"

# Configure logging first so anything emitted while loading the model is
# captured too. filemode='w' truncates ./record.log on every start.
logging.basicConfig(level=logging.DEBUG,
                    format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                    datefmt='%a, %d %b %Y %H:%M:%S',
                    filename='./record.log',
                    filemode='w')

# Load the exported ONNX model once at import time.
ortSession = onnxruntime.InferenceSession("./resnet34.onnx")
# Explicit dim: nn.Softmax() without `dim` is deprecated and emits a
# UserWarning. The last axis is the only axis of a 1-D score vector.
softmax = torch.nn.Softmax(dim=-1)

# Label lookup table loaded from ./index.json; presumably keyed by the class
# index as a string — verify against the exporter of index.json.
with open("./index.json", "r", encoding="utf-8") as f:
    dictIdImage = json.load(f)
@app.route("/")
def index():
    """Landing endpoint: report which service is listening on this port."""
    service_banner = "image classify service resnet34"
    return service_banner
@app.route("/help")
def help():
    """Return plain-text usage notes for the /classify endpoint.

    NOTE: the name shadows the ``help`` builtin at module level; it is kept
    because Flask derives the endpoint name ("help") from the function name.
    """
    return '''
    

python接口调用

import requests

data = {"imagename":"test.jpg"}

r = requests.post("http://xx.xx.xx.xx:5001/classify", data=data).json()

'''


# Preprocessing applied after letterboxing: uint8 CHW tensor -> PIL image ->
# float tensor in [0, 1] (ToTensor) -> [-1, 1] (Normalize with 0.5 / 0.5).
# Built once at import instead of on every request.
_TO_MODEL_INPUT = T.Compose([
    T.ToPILImage(),
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


def _resolve_image_path(image_name):
    """Map a client-supplied file name to a path inside ``imageDir``.

    Returns the resolved absolute path, or ``None`` when the name does not
    end in ``.jpg`` (case-insensitive) or would resolve outside ``imageDir``.
    Examples of the latter are ``../secret.jpg`` or an absolute path, which
    ``os.path.join`` would otherwise use verbatim.
    """
    if not image_name.lower().endswith(".jpg"):
        return None
    base_dir = os.path.realpath(imageDir)
    candidate = os.path.realpath(os.path.join(base_dir, image_name))
    try:
        inside = os.path.commonpath([base_dir, candidate]) == base_dir
    except ValueError:  # e.g. the two paths are on different drives (Windows)
        return None
    return candidate if inside else None


@app.route("/classify", methods=["POST"])
def classifyImage():
    """Classify the image in ``imageDir`` named by the ``imagename`` form field.

    JSON responses:
      * 200 ``{"msg": "ok", "result": {"label": ..., "prob": ...}}``.
        ``prob`` is the top softmax probability rounded to 3 decimals, sent
        as a string.
      * 400 ``{"msg": "illegal image path"}``: the field is missing, is not a
        ``.jpg`` name, or points outside ``imageDir``.
      * 400 ``{"msg": "image not found"}``: the file could not be read.
      * 500 ``{"msg": "label not found"}``: the predicted class id is absent
        from the label table.
    """
    image_name = request.form.get("imagename", "")
    logging.info("ip %s post image %s", request.remote_addr, image_name)

    image_path = _resolve_image_path(image_name)
    if image_path is None:
        logging.warning("illegal image path")
        return jsonify({"msg": "illegal image path"}), 400

    ok, img = image_process(image_path)
    if not ok:
        logging.warning("image not found")
        return jsonify({"msg": "image not found"}), 400

    # The ONNX graph's input is named "input". output[0][0] is the first
    # output's row for the single image in the batch (presumably raw class
    # scores, hence the softmax).
    output = ortSession.run(None, {"input": img.numpy()})
    pred = softmax(torch.from_numpy(output[0][0])).numpy()
    class_id = int(np.argmax(pred))
    prob = str(np.round(pred[class_id], 3))
    label = dictIdImage.get(str(class_id))
    if label is None:
        logging.error("class id %s missing from index.json", class_id)
        return jsonify({"msg": "label not found"}), 500
    logging.info("imagename: %s predict:%s probability: %s", image_name, label, prob)
    return jsonify({"msg": "ok", "result": {"label": label, "prob": prob}}), 200


def image_process(image_path):
    """Read an image file and turn it into a one-image model input batch.

    Returns ``(True, tensor)``, where the tensor has shape (1, C, 450, 640)
    and values in [-1, 1]. C is 3 under ``cv2.imread``'s default colour mode.
    Returns ``(False, None)`` if the file cannot be read or decoded.
    """
    try:
        image = cv2.imread(image_path)
    except Exception:  # cv2.imread can raise (e.g. cv2.error); treat as unreadable
        logging.exception("failed to read image %s", image_path)
        return False, None
    if image is None:  # cv2.imread reports missing/undecodable files with None
        return False, None
    # NOTE(review): cv2 yields BGR channel order, which is fed to the model
    # unchanged. Presumably this matches the training pipeline — verify.
    image = resize_image(image, [640, 450])
    return True, _TO_MODEL_INPUT(image).unsqueeze(0)


def resize_image(img0, shape, color=(127.5, 127.5, 127.5)):
    """Letterbox ``img0`` to ``shape`` = (width, height) and return a CHW tensor.

    ``img0`` is expected to be an HxWxC array. It is scaled by the largest
    factor that fits the target while keeping its aspect ratio, then padded
    evenly on both sides with the constant ``color``. Returns a torch tensor
    of shape (C, height, width) with ``img0``'s dtype.
    """
    target_w, target_h = shape[0], shape[1]
    src_h, src_w = img0.shape[:2]
    ratio = min(target_w / src_w, target_h / src_h)
    new_w, new_h = round(src_w * ratio), round(src_h * ratio)
    img = cv2.resize(img0, (new_w, new_h), interpolation=cv2.INTER_AREA)

    # The total padding is an integer, so each half is a whole or half number.
    # The +/-0.1 nudge puts the extra pixel of an odd split on bottom/right.
    dw, dh = (target_w - new_w) / 2, (target_h - new_h) / 2
    top, bottom = round(dh - 0.1), round(dh + 0.1)
    left, right = round(dw - 0.1), round(dw + 0.1)
    # `value` must be a keyword argument. The 7th positional parameter of
    # cv2.copyMakeBorder is `dst`, where the colour would be ignored and the
    # padding would come out black.
    # NOTE(review): if the model was evaluated/trained with black padding,
    # pass color=(0, 0, 0) to keep identical inputs.
    img = cv2.copyMakeBorder(img, top, bottom, left, right,
                             cv2.BORDER_CONSTANT, value=color)
    # HWC -> CHW, made contiguous before wrapping as a tensor.
    img = np.ascontiguousarray(img.transpose((2, 0, 1)))
    return torch.from_numpy(img)


if __name__ == "__main__":
    # debug=True enables the Werkzeug interactive debugger, which lets anyone
    # who can reach the port execute code. Enable it only explicitly
    # (FLASK_DEBUG=1) on a trusted machine. For production, a WSGI server
    # (e.g. waitress/gunicorn) is preferable to the Flask dev server.
    app.run(host="xx.xx.xx.xx", port=5001,
            debug=os.environ.get("FLASK_DEBUG") == "1")

2. 客户端调用示例

import requests

# Ask the service to classify "test.jpg"; the server resolves the name
# relative to its configured image directory. The timeout keeps the client
# from hanging indefinitely if the server is unreachable or stalls.
data = {"imagename": "test.jpg"}
r = requests.post("http://xx.xx.xx.xx:5001/classify", data=data, timeout=30).json()
print(r)

你可能感兴趣的:(python,pytorch,flask)