#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
import cv2
import onnx
import torch
import numpy as np
import onnxruntime
from PIL import Image
import torchvision.transforms as trans
warnings.filterwarnings("ignore")
class ONNXModel(object):
    """Thin wrapper around an onnxruntime ``InferenceSession``.

    Loads a model from disk, caches the session's input/output node names,
    and exposes ``forward`` which feeds one tensor to every input and
    returns all outputs stacked into a single numpy array.
    """

    def __init__(self, onnx_path):
        """Create the inference session and record its input/output names.

        :param onnx_path: filesystem path to the ``.onnx`` model file
        """
        self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)
        print("input_name:{}".format(self.input_name))
        print("output_name:{}".format(self.output_name))

    def get_output_name(self, onnx_session):
        """Collect the names of every output node of the session.

        :param onnx_session: an onnxruntime ``InferenceSession``
        :return: list of output node names
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Collect the names of every input node of the session.

        :param onnx_session: an onnxruntime ``InferenceSession``
        :return: list of input node names
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_tensor):
        """Build the feed dict, mapping each input name to the same tensor.

        :param input_name: iterable of input node names
        :param image_tensor: tensor handed to every listed input
        :return: dict usable as ``input_feed`` for ``InferenceSession.run``
        """
        return {name: image_tensor for name in input_name}

    def forward(self, image_tensor):
        """Run inference on *image_tensor* and return the stacked outputs.

        The tensor's dtype and shape must match what the model expects
        (e.g. float32 ``(1, C, H, W)`` for a typical image classifier).

        :param image_tensor: numpy array fed to every model input
        :return: ``np.array`` wrapping all session outputs
        """
        feed = self.get_input_feed(self.input_name, image_tensor)
        raw_outputs = self.onnx_session.run(self.output_name, input_feed=feed)
        return np.array(raw_outputs)
# Demo: classify one image with the exported Xception ONNX model.
resizer = trans.Compose([
    # Resize to the exact (224, 224) spatial size the graph was exported with.
    # The original `Resize((224))` is just Resize(224): it only fixes the
    # shorter edge and keeps the aspect ratio, which breaks ONNX models
    # exported with a fixed (1, 3, 224, 224) input on non-square images.
    trans.Resize((224, 224)),
    trans.ToTensor(),
    trans.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
])
label_idx = list(range(7))  # the 7 class indices selected from the output
worker = ONNXModel("./xception.onnx")
img = cv2.imread("./sample.jpg")
img_in = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; model expects RGB
img_in = resizer(Image.fromarray(img_in))  # ndarray -> PIL image -> normalized tensor
img_in = torch.unsqueeze(img_in, 0).numpy()  # add batch dim -> (1, 3, 224, 224)
print(img_in.shape)
output = worker.forward(img_in)
output = torch.from_numpy(output)  # numpy -> tensor
# Drop the leading "session outputs" axis, keep the batch axis, and pick the
# 7 class scores.
# NOTE(review): relu before softmax zeroes negative logits and distorts the
# resulting distribution — confirm this is intended.
output = output.relu()[0, :, label_idx, ...]
print("output:", output)
output = output.softmax(1)  # probabilities over the 7 classes
print("softmax:", output)
prob_vector = output.detach().cpu().numpy().tolist()
prob_vector = np.round(prob_vector, 3)
print("prob_vector:", prob_vector[0])