ONNX export code for Ultralytics YOLOv8.0.225:
import argparse
import os
import torch
from onnxsim import simplify
from ultralytics.nn import SegmentationModel
from ultralytics.nn.modules import C2f
from ultralytics.nn.tasks import attempt_load_weights
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='../yolov8n-cls.pt', help='weights path')
    parser.add_argument('--official_weights_onnx', type=str, default='official_weights_onnx',
                        help='output directory for the exported ONNX files')
    parser.add_argument('--img_size', nargs='+', type=int, default=[640, 640], help='image size')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    opt = parser.parse_args()
    opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # expand a single value to [h, w]
    save_p = opt.official_weights_onnx
    if not os.path.exists(save_p):
        os.makedirs(save_p)
    img = torch.zeros((opt.batch_size, 3, *opt.img_size))  # dummy input for tracing
    model = attempt_load_weights(opt.weights,
                                 device=torch.device('cpu'),
                                 inplace=True,
                                 fuse=True)
    model.model[-1].export = True  # set Detect() layer export=True
    for k, m in model.named_modules():
        if isinstance(m, C2f):
            # EdgeTPU does not support FlexSplitV, while split() yields a cleaner ONNX graph
            m.forward = m.forward_split
        else:
            # flag the remaining modules (including the head) for static-shape ONNX export
            m.dynamic = False
            m.export = True
            m.format = 'onnx'
    model.eval()
    y = model(img)  # dry run
save_p = "official_weights_onnx"
if not os.path.exists(save_p):
os.makedirs(save_p)
    # ONNX export
    try:
        import onnx

        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        # e.g. official_weights_onnx/BS_1_640_yolov8n-cls.onnx
        f = save_p + "/" + "BS_" + str(opt.batch_size) + "_" + str(list(opt.img_size)[0]) + "_" + \
            opt.weights.replace(".pt", ".onnx").split("/")[-1]  # output filename
        model.fuse()  # fuse Conv+BN layers, only for ONNX
        print("=========== onnx =========== ")
        torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
                          output_names=['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0'])
        # Checks
        onnx_model = onnx.load(f)  # load onnx model
        onnx.checker.check_model(onnx_model)  # check onnx model
        print(onnx.helper.printable_graph(onnx_model.graph))  # print a human-readable graph
        print('ONNX export success, saved as %s' % f)
        # Simplify with onnx-simplifier
        model_simp, check = simplify(onnx_model)
        assert check, "Simplified ONNX model could not be validated"
        # insert a "sim_" prefix before "BS", e.g. official_weights_onnx/sim_BS_1_640_yolov8n-cls.onnx
        SAVE_INFO_SIM = f.split("BS")
        SAVE_SIM_NAME = SAVE_INFO_SIM[-2] + "sim_BS" + SAVE_INFO_SIM[-1]
        onnx.save(model_simp, SAVE_SIM_NAME)
        print('finished exporting simplified onnx as', SAVE_SIM_NAME)
    except Exception as e:
        print('ONNX export failure: %s' % e)
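
After running the script, the exported file can be sanity-checked with onnxruntime. Below is a minimal sketch, assuming onnxruntime is installed and that the simplified model was written to official_weights_onnx/sim_BS_1_640_yolov8n-cls.onnx (a hypothetical path following the naming scheme above); neither is part of the export script itself.

# Sanity-check the exported ONNX model with onnxruntime.
# Assumptions: onnxruntime is installed; the output path below follows the
# script's naming scheme and may differ for your weights/img_size/batch size.
import numpy as np
import onnxruntime as ort

onnx_path = "official_weights_onnx/sim_BS_1_640_yolov8n-cls.onnx"  # hypothetical output path
session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

input_name = session.get_inputs()[0].name              # 'images', as set in torch.onnx.export
dummy = np.zeros((1, 3, 640, 640), dtype=np.float32)   # same shape as the dry-run tensor

outputs = session.run(None, {input_name: dummy})
for out in outputs:
    print(out.shape)  # output shape depends on the task head (classify/detect/segment)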