PyTorch模型转ONNX格式

前言

在使用PyTorch进行网络训练得到.pth模型文件后,我们可能会做一些模型部署和加速的工作。这里一般会涉及到将PyTorch模型转为ONNX模型的过程。
在这里插入图片描述
PyTorch自带了ONNX转换方法(torch.onnx.export),可以很方便的将一些仅包含通用算子的网络的PyTorch模型转为ONNX格式。

转换步骤

准备模型文件

将PyTorch模型文件准备好,放在’./weights/torch.pth’路径下。

编写转换代码

__author__ = 'TracelessLe'

import torch

TORCH_WEIGHT_PATH = './weights/torch.pth'
ONNX_MODEL_PATH = 'net_bs8_v1.onnx'

def get_numpy_data():
	batch_size = 8
    img_input = np.ones((batch_size, 3, 128, 128), dtype=np.float32)
    return img_input

def get_torch_model():
    # Load Network Here
    pass 

def torch2onnx(img_input, onnx_model_path, device_id=0):
    torch_model = get_torch_model()  # Network define
    device = 'cpu' if device_id < 0 else f'cuda:{device_id}'
    torch_model.to(device)
    torch_weights = torch.load(TORCH_WEIGHT_PATH)
    torch_model.load_state_dict(torch_weights)
    torch_model.eval()
    dummy_img = torch.Tensor(img_input).to(device)
    torch.onnx.export(
        torch_model,
        (dummy_img),
        onnx_model_path,
        input_names=['input'],
        output_names=['output'],
        export_params=True,
        verbose=True,
        do_constant_folding=False,  # or True
        opset_version=11
    )
    print("Generate ONNX file over!")

if __name__ == "__main__":
    img_input = get_numpy_data()
    torch2onnx(img_input, ONNX_MODEL_PATH)

简化ONNX模型结构(可选)

生成的ONNX结构可能还有简化的空间,可以使用onnx-simplifier工具进一步优化。

__author__ = 'TracelessLe'

import onnx
from onnxsim import simplify

ONNX_MODEL_PATH = 'net_bs8_v1.onnx'
ONNX_SIM_MODEL_PATH = 'net_bs8_v1_simple.onnx'

if __name__ == "__main__":
    onnx_model = onnx.load(ONNX_MODEL_PATH)
    onnx_sim_model, check = simplify(onnx_model)
    assert check, "Simplified ONNX model could not be validated"
    onnx.save(onnx_sim_model, ONNX_SIM_MODEL_PATH)
    print('ONNX file simplified!')

对比结果

__author__ = 'TracelessLe'

import time
import torch
import onnxruntime
import numpy as np

def test_torch(img_input, device_id=0, loop=100):
    torch_model = get_torch_model()
    device = 'cpu' if device_id < 0 else f'cuda:{device_id}'
    torch_model.to(device)
    torch_weights = torch.load(TORCH_WEIGHT_PATH)
    torch_model.load_state_dict(torch_weights)
    torch_model.eval()
    dummy_img = torch.Tensor(img_input).to(device)
    batch_size = 8
    with torch.no_grad():
        out_img = torch_model(dummy_img)
    time1 = time.time()
    for i in range(loop):
        time_bs1 = time.time()
        with torch.no_grad():
            out_img = torch_model(dummy_img)
            out_img_numpy = out_img.detach().cpu().numpy()
        time_bs2 = time.time()
        time_use_pt_bs = time_bs2 - time_bs1
        print(f'PyTorch use time {time_use_pt_bs} for bs8')
    time2 = time.time()
    time_use_pt = time2-time1
    print(f'PyTorch use time {time_use_pt} for loop {loop}, FPS={loop*batch_size//time_use_pt}')
    return out_img_numpy

def test_onnx(inputs, loop=100):
    inputs = inputs.astype(np.float32)
    print(onnxruntime.get_device())
    sess = onnxruntime.InferenceSession(ONNX_SIM_MODEL_PATH)
    batch_size = 8
    time1 = time.time()
    for i in range(loop):
        time_bs1 = time.time()
        out_ort_img = sess.run(None, {sess.get_inputs()[0].name: inputs,})
        time_bs2 = time.time()
        time_use_onnx_bs = time_bs2 - time_bs1
        print(f'ONNX use time {time_use_onnx_bs} for bs8')
    time2 = time.time()
    time_use_onnx = time2-time1
    print(f'ONNX use time {time_use_onnx} for loop {loop}, FPS={loop*batch_size//time_use_onnx}')
    return out_ort_img

if __name__ == "__main__":
    img_input = get_numpy_data()
    out_ort_img = test_onnx(img_input, loop=1)[0]
    out_img_numpy = test_torch(img_input, loop=1)
    mse = np.square(np.subtract(out_ort_img, out_img_numpy)).mean()
    print('mse between pytorch and onnx result: ', mse)

模型可视化(可选)

可以使用Netron可视化PyTorch、ONNX等模型结构。
在这里插入图片描述
通过Netron打开转换得到的ONNX以及简化后的ONNX文件,可以对比查看网络结构的变化。
PyTorch模型转ONNX格式_第1张图片

其他说明

(1)do_constant_folding
在使用torch.onnx.export时可能会报错(与do_constant_folding相关),那么可以将do_constant_folding=True设为do_constant_folding=False

(2)自定义算子的转换
PyTorch转ONNX目前仅支持一些通用算子(见PyTorch Doc),自定义的算子在转出时会报错。可以通过改写一部分算子或者使用功能相近的通用算子替代。

(3)PyTorch转ONNX的方式
目前PyTorch转ONNX有两种方式,除了主流的torch.onnx.export外,还有torch.jit.trace。但是后一种还处于实验阶段,可能会遇到较多问题。

版权说明

本文为原创文章,独家发布在blog.csdn.net/TracelessLe。未经个人允许不得转载。如需帮助请email至[email protected]
PyTorch模型转ONNX格式_第2张图片

参考资料

[1] ONNX | Home
[2] torch.onnx — PyTorch 1.9.0 documentation
[3] (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime — PyTorch Tutorials 1.9.0+cu102 documentation
[4] daquexian/onnx-simplifier: Simplify your onnx model
[5] lutzroeder/netron: Visualizer for neural network, deep learning, and machine learning models

你可能感兴趣的:(#,深度学习框架,pytorch,深度学习,神经网络,ONNX)