Ultralytics YOLOv8.0.225 的 ONNX 导出

Ultralytics YOLOv8.0.225 的 ONNX 导出代码:

import argparse
import os
import torch
from onnxsim import simplify
from ultralytics.nn import SegmentationModel
from ultralytics.nn.modules import C2f
from ultralytics.nn.tasks import attempt_load_weights

def build_onnx_path(save_dir, weights, batch_size, img_size):
    """Return the path of the exported ONNX file.

    The file is placed in ``save_dir`` and named
    ``BS_<batch_size>_<img_size[0]>_<weights file name>.onnx``.

    Only a trailing ``.pt`` extension of the weights file name is swapped for
    ``.onnx``; any other name just gets ``.onnx`` appended.  A plain
    ``str.replace("pt", "onnx")`` would also rewrite a "pt" that occurs
    elsewhere in the name (e.g. ``adapt.pt``).

    Args:
        save_dir: Directory the ONNX file is written to.
        weights: Path to the weights file.
        batch_size: Batch size used for the export.
        img_size: Sequence of image dimensions; only the first is used.

    Returns:
        The output path as a string.
    """
    base = os.path.basename(weights)
    stem, ext = os.path.splitext(base)
    onnx_name = (stem if ext == ".pt" else base) + ".onnx"
    return os.path.join(save_dir, "BS_%s_%s_%s" % (batch_size, img_size[0], onnx_name))


def simplified_onnx_path(onnx_path):
    """Return the path for the simplified model: same directory, ``sim_`` prefix.

    The prefix is applied to the file name only, so a "BS" occurring in the
    directory or weights name cannot corrupt the result.
    """
    directory, name = os.path.split(onnx_path)
    return os.path.join(directory, "sim_" + name)


def main():
    """Export the given weights to ONNX and write a simplified copy next to it."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='../yolov8n-cls.pt', help='weights path')
    parser.add_argument('--official_weights_onnx', type=str, default="official_weights_onnx",
                        help='official_weights_onnx')
    parser.add_argument('--img_size', nargs='+', type=int, default=[640, 640], help='image size')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    opt = parser.parse_args()
    opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # expand a single value to [size, size]

    # Honour the user-supplied output directory for the whole run.
    save_p = opt.official_weights_onnx
    os.makedirs(save_p, exist_ok=True)

    img = torch.zeros((opt.batch_size, 3, *opt.img_size))  # dummy input used for tracing
    model = attempt_load_weights(opt.weights,
                                 device=torch.device('cpu'),
                                 inplace=True,
                                 fuse=True)
    for m in model.modules():
        if isinstance(m, C2f):
            # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
            m.forward = m.forward_split
        else:
            m.dynamic = False
            m.export = True
            m.format = "onnx"

    model.eval()
    model.model[-1].export = True  # make sure the last layer is in export mode
    model(img)  # dry run

    # ONNX export
    try:
        import onnx

        print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
        f = build_onnx_path(save_p, opt.weights, opt.batch_size, opt.img_size)  # filename
        model.fuse()  # only for ONNX

        print("===========  onnx =========== ")
        if isinstance(model, SegmentationModel):
            output_names = ['output0', 'output1']
        else:
            output_names = ['output0']
        torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
                          output_names=output_names)

        # Checks
        onnx_model = onnx.load(f)  # load onnx model
        onnx.checker.check_model(onnx_model)  # check onnx model
        print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model
        print('ONNX export success, saved as %s' % f)

        # Simplify the exported graph and store it alongside the original.
        model_simp, check = simplify(onnx_model)
        if not check:
            # Explicit raise instead of assert so the check survives `python -O`.
            raise RuntimeError("Simplified ONNX model could not be validated")
        sim_path = simplified_onnx_path(f)
        onnx.save(model_simp, sim_path)
        print('finished exporting Simplified onnx as ', sim_path)
    except Exception as e:
        # Best-effort export: report the failure instead of raising.
        print('ONNX export failure: %s' % e)


if __name__ == '__main__':
    main()

你可能感兴趣的:(YOLO,pytorch,深度学习)