首先在 detr 项目根目录下创建 onnx 文件夹，用于存放 DETR 的 pth 权重文件（导出脚本默认读取 onnx/detr-r50-e632da11.pth）；后续导出的 onnx 文件也保存在该目录中。
然后在 detr 项目根目录下创建 export_onnx.py 文件，将下面的代码复制进去后直接运行，即可导出 detr.onnx 模型；模型默认保存在 onnx 文件夹下（可通过 --onnx_dir 参数修改保存路径）。
import io
import argparse
import onnx
import onnxruntime
import torch
from hubconf import detr_resnet50
class ONNXExporter:
    """Export a PyTorch model to ONNX and validate the result with ONNX Runtime.

    Adapted from a unittest-style helper (hence ``setUpClass``): ``run_model``
    exports the model and compares ONNX Runtime outputs against PyTorch, and
    ``check_onnx`` runs the ONNX structural checker on a saved file.
    """

    @classmethod
    def setUpClass(cls):
        """Seed PyTorch's global RNG with 123.

        Nothing in this script calls it; call it explicitly when reproducible
        random inputs are needed.
        """
        torch.manual_seed(123)

    def run_model(self, model, onnx_path, inputs_list, tolerate_small_mismatch=False,
                  do_constant_folding=True,
                  output_names=None, input_names=None):
        """Export ``model`` to ``onnx_path`` and validate it with ONNX Runtime.

        Args:
            model: the ``torch.nn.Module`` to export; it is switched to eval mode.
            onnx_path: file path the ONNX model is written to.
            inputs_list: sequence of example inputs. The first one is used for
                tracing; every one is used for numeric validation.
            tolerate_small_mismatch: if True, numeric mismatches are printed
                instead of raised.
            do_constant_folding: forwarded to ``torch.onnx.export``.
            output_names: forwarded to ``torch.onnx.export``.
            input_names: forwarded to ``torch.onnx.export``.

        Raises:
            AssertionError: an output differs beyond tolerance and
                ``tolerate_small_mismatch`` is False.
        """
        model.eval()
        # Trace/export only once (tracing is the slow step) into memory, then
        # write those exact bytes to disk, so the saved file is byte-for-byte
        # the model that ONNX Runtime validates below.
        onnx_io = io.BytesIO()
        torch.onnx.export(
            model, inputs_list[0], onnx_io,
            input_names=input_names, output_names=output_names,
            export_params=True,
            # Newer PyTorch rejects a plain bool here and requires the enum.
            training=torch.onnx.TrainingMode.EVAL,
            opset_version=12,
            do_constant_folding=do_constant_folding,
        )
        with open(onnx_path, "wb") as f:
            f.write(onnx_io.getvalue())
        print(f"[INFO] ONNX model export success! save path: {onnx_path}")

        # validate the exported model with ONNX Runtime on every example input
        for test_inputs in inputs_list:
            with torch.no_grad():
                if isinstance(test_inputs, (torch.Tensor, list)):
                    # wrap a single input so that model(*test_inputs) passes it as one argument
                    test_inputs = (test_inputs,)
                test_outputs = model(*test_inputs)
                if isinstance(test_outputs, torch.Tensor):
                    test_outputs = (test_outputs,)
                self.ort_validate(onnx_io, test_inputs, test_outputs, tolerate_small_mismatch)

    def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):
        """Run the serialized ONNX model in ONNX Runtime and compare with PyTorch.

        Args:
            onnx_io: ``io.BytesIO`` holding the serialized ONNX model.
            inputs: (possibly nested) tensors that were fed to the PyTorch model.
            outputs: (possibly nested) tensors the PyTorch model produced.
            tolerate_small_mismatch: if True, print mismatches instead of raising.

        Raises:
            AssertionError: an output differs beyond rtol=1e-3 / atol=1e-5 and
                ``tolerate_small_mismatch`` is False.
        """
        import numpy as np  # installed alongside torch; gives a version-stable comparison

        inputs, _ = torch.jit._flatten(inputs)
        outputs, _ = torch.jit._flatten(outputs)

        def to_numpy(tensor):
            # detach() is harmless on tensors that do not require grad
            return tensor.detach().cpu().numpy()

        inputs = [to_numpy(t) for t in inputs]
        outputs = [to_numpy(t) for t in outputs]

        # GPU builds of ONNX Runtime (>= 1.9) refuse to create a session without
        # an explicit provider list; the CPU provider is always available.
        ort_session = onnxruntime.InferenceSession(
            onnx_io.getvalue(), providers=["CPUExecutionProvider"])
        input_meta = ort_session.get_inputs()
        ort_inputs = {input_meta[i].name: value for i, value in enumerate(inputs)}
        # compute onnxruntime output prediction
        ort_outs = ort_session.run(None, ort_inputs)
        for i, torch_out in enumerate(outputs):
            try:
                # torch.testing.assert_allclose is deprecated in newer PyTorch;
                # numpy applies the same |a - b| <= atol + rtol * |b| rule.
                np.testing.assert_allclose(torch_out, ort_outs[i], rtol=1e-03, atol=1e-05)
            except AssertionError as error:
                if tolerate_small_mismatch:
                    print(error)
                else:
                    raise

    @staticmethod
    def check_onnx(onnx_path):
        """Load the ONNX file at ``onnx_path`` and run ``onnx.checker.check_model`` on it."""
        model = onnx.load(onnx_path)
        onnx.checker.check_model(model)
        print(f"[INFO] ONNX model: {onnx_path} check success!")
if __name__ == '__main__':
    import os  # only the script entry point needs it (creating the output directory)

    parser = argparse.ArgumentParser(description='DETR Model to ONNX Model')
    # path of the DETR .pth checkpoint to convert
    parser.add_argument('--model_dir', type=str, default='onnx/detr-r50-e632da11.pth',
                        help='DETR Pytorch Model Saved Dir')
    # The check runs by default. `--check` is still accepted; `--no-check` disables it.
    # With store_true and default=True alone there was no way to turn it off.
    parser.add_argument('--check', dest='check', default=True, action="store_true",
                        help='Check Your ONNX Model')
    parser.add_argument('--no-check', dest='check', action='store_false',
                        help='Skip the ONNX model check')
    # where the exported ONNX model is written
    parser.add_argument('--onnx_dir', type=str, default="onnx/detr.onnx",
                        help="Path to save the exported ONNX model")
    parser.add_argument('--batch_size', type=int, default=1, help="Batch Size")
    args = parser.parse_args()

    # load torch model; num_classes is the max label index (90) plus 1
    detr = detr_resnet50(pretrained=False, num_classes=90 + 1).eval()
    # map to CPU so the export also works on machines without a GPU
    state_dict = torch.load(args.model_dir, map_location='cpu')
    detr.load_state_dict(state_dict["model"])

    # make sure the output directory exists before the model is saved
    onnx_parent = os.path.dirname(args.onnx_dir)
    if onnx_parent:
        os.makedirs(onnx_parent, exist_ok=True)

    # Dummy input used for tracing. No dynamic axes are declared, so the
    # exported graph expects exactly this (batch_size, 3, 800, 800) shape.
    dummy_image = [torch.ones(args.batch_size, 3, 800, 800)]

    # export to ONNX and validate against PyTorch
    onnx_export = ONNXExporter()
    onnx_export.run_model(detr, args.onnx_dir, dummy_image, input_names=['inputs'],
                          output_names=["pred_logits", "pred_boxes"], tolerate_small_mismatch=True)

    # structural check of the saved ONNX model
    if args.check:
        ONNXExporter.check_onnx(args.onnx_dir)
导出过程中可能会输出一些警告信息，直接忽略即可，稍等一两分钟就能完成 ONNX 模型的导出。
导出完成后，在同一目录下创建 inference_onnx.py 文件，使用刚才导出的 onnx 模型进行预测。待检测的图片放在 onnx/images 文件夹中，检测结果图片会保存到 onnx/result 文件夹下。
import cv2
from PIL import Image
import numpy as np
import os
import random
try:
import onnxruntime
except ImportError:
onnxruntime = None
import torch
import torchvision.transforms as T
# Inference-only script: switch autograd off globally.
torch.set_grad_enabled(False)

# Preferred torch device (not referenced by the ONNX Runtime code below).
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing for the ONNX model. The export script traces an 800x800 dummy
# input without dynamic axes, so every image is resized to exactly 800x800,
# converted to a CHW float tensor, and normalized with the standard ImageNet
# channel mean/std.
transform = T.Compose([
    T.Resize(size=(800, 800)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def box_cxcywh_to_xyxy(x):
    """Convert (cx, cy, w, h) boxes to (x1, y1, x2, y2) corner form.

    Args:
        x: numpy array of shape (N, 4), one (cx, cy, w, h) row per box.

    Returns:
        A torch tensor of shape (N, 4) holding (x1, y1, x2, y2) per row.
    """
    boxes = torch.from_numpy(x)
    centre_x, centre_y, width, height = boxes.unbind(1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = [centre_x - half_w, centre_y - half_h,
               centre_x + half_w, centre_y + half_h]
    return torch.stack(corners, dim=1)
def rescale_bboxes(out_bbox, size):
    """Turn (cx, cy, w, h) boxes into absolute (x1, y1, x2, y2) pixel boxes.

    Args:
        out_bbox: numpy array of shape (N, 4) in (cx, cy, w, h) form,
            presumably normalized to [0, 1] by the model (confirm upstream).
        size: (width, height) of the original image, e.g. PIL ``Image.size``.

    Returns:
        numpy array of shape (N, 4) scaled by (width, height, width, height).
    """
    width, height = size
    corners = box_cxcywh_to_xyxy(out_bbox).cpu().numpy()
    scale = np.array([width, height, width, height], dtype=np.float32)
    return corners * scale
def plot_one_box(x, img, color=None, label=None, line_thickness=1):
    """Draw one bounding box, and optionally a filled label tag, onto ``img`` in place.

    Args:
        x: box corners (x1, y1, x2, y2); values are truncated to int.
        img: numpy image array that OpenCV draws on directly.
        color: box/tag color; a random color is chosen when falsy.
        label: optional text drawn just above the top-left corner.
        line_thickness: box line thickness; when falsy it is derived from the
            image size.
    """
    thickness = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]
    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=thickness, lineType=cv2.LINE_AA)
    if not label:
        return

    font_thickness = max(thickness - 1, 1)
    font_scale = thickness / 3
    # font id 0 is OpenCV's FONT_HERSHEY_SIMPLEX
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thickness)[0]
    tag_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
    cv2.rectangle(img, top_left, tag_corner, color, -1, cv2.LINE_AA)  # filled label background
    cv2.putText(img, label, (top_left[0], top_left[1] - 2), 0, font_scale, [225, 255, 255],
                thickness=font_thickness, lineType=cv2.LINE_AA)
# Label names indexed by class id (91 entries, ids 0-90). plot_result indexes
# this list with the argmax over the kept probability columns, matching the
# num_classes=90 + 1 used when exporting the model.
# 'N/A' entries are ids with no category name; presumably the gaps in the COCO
# category ids — verify against the training annotations.
CLASSES = [
'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
def plot_result(pil_img, prob, boxes, save_name=None, imshow=False, imwrite=False):
    """Draw detections on a BGR copy of ``pil_img`` and optionally show and/or save it.

    Args:
        pil_img: RGB PIL image the detections belong to.
        prob: per-detection class probability rows; the argmax picks the label.
        boxes: per-detection (xmin, ymin, xmax, ymax) in pixels of ``pil_img``.
        save_name: file name used under ``onnx/result/`` when ``imwrite`` is True.
        imshow: show the result in an OpenCV window and wait for a key press.
        imwrite: write the result to ``onnx/result/<save_name>``.
    """
    # OpenCV draws and saves in BGR order, PIL provides RGB
    bgr_image = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes):
        cl = p.argmax()
        label_text = '{} {}%'.format(CLASSES[cl], round(p[cl] * 100, 2))
        plot_one_box((xmin, ymin, xmax, ymax), bgr_image, label=label_text)
    if imshow:
        cv2.imshow('detect', bgr_image)
        cv2.waitKey(0)
    if imwrite:
        # exist_ok avoids the check-then-create race of exists() + makedirs()
        os.makedirs('onnx/result', exist_ok=True)
        out_path = 'onnx/result/{}'.format(save_name)
        # cv2.imwrite reports failure (e.g. unsupported extension) only through
        # its return value, so surface it instead of failing silently.
        if not cv2.imwrite(out_path, bgr_image):
            print('[WARN] failed to write {}'.format(out_path))
def detect_onnx(ort_session, im, prob_threshold=0.7):
    """Run the exported DETR ONNX model on one image and keep confident detections.

    Args:
        ort_session: ``onnxruntime.InferenceSession`` for the exported model,
            whose outputs are assumed to be (pred_logits, pred_boxes) in that
            order, as named at export time.
        im: PIL image; should be RGB, since ``transform`` normalizes 3 channels.
        prob_threshold: a query is kept when its highest class probability
            exceeds this value.

    Returns:
        (probas, boxes): numpy class-probability rows of the kept queries
        (without the last logit column), and their boxes as
        (xmin, ymin, xmax, ymax) in pixels of ``im``.
    """
    img = transform(im).unsqueeze(0).cpu().numpy()
    # Read the input name from the session instead of hard-coding "inputs",
    # so models exported with a different input name also work.
    ort_inputs = {ort_session.get_inputs()[0].name: img}
    scores, boxs = ort_session.run(None, ort_inputs)
    # Softmax over classes for the first image in the batch, then drop the last
    # column (presumably DETR's "no object" class — not a real label).
    probas = torch.from_numpy(np.asarray(scores)).softmax(-1)[0, :, :-1]
    keep = (probas.max(-1).values > prob_threshold).cpu().numpy()
    probas = probas.cpu().numpy()
    bboxes_scaled = rescale_bboxes(boxs[0, keep], im.size)
    return probas[keep], bboxes_scaled
if __name__ == "__main__":
    # the import at the top of the file is guarded, so fail with a clear message
    if onnxruntime is None:
        raise ImportError("onnxruntime is required for ONNX inference: pip install onnxruntime")

    onnx_path = "onnx/detr.onnx"
    # GPU builds of ONNX Runtime (>= 1.9) require the provider list to be passed explicitly.
    ort_session = onnxruntime.InferenceSession(
        onnx_path, providers=onnxruntime.get_available_providers())

    image_dir = "onnx/images"
    for file in sorted(os.listdir(image_dir)):
        img_path = os.path.join(image_dir, file)
        if not os.path.isfile(img_path):  # skip sub-directories
            continue
        try:
            # Convert to RGB: grayscale/RGBA/palette images would otherwise break
            # the 3-channel Normalize and the RGB->BGR conversion when plotting.
            # The `with` block closes the underlying file handle.
            with Image.open(img_path) as src:
                im = src.convert("RGB")
        except OSError as err:  # includes PIL.UnidentifiedImageError for non-images
            print(f"[WARN] skip {img_path}: {err}")
            continue
        scores, boxes = detect_onnx(ort_session, im)
        plot_result(im, scores, boxes, save_name=file, imshow=False, imwrite=True)
预测结果（结果图片保存在 onnx/result 目录中）：
如需直接使用 pth 模型进行推理，可参考：DETR推理代码_athrunsunny的博客-CSDN博客