TRT8系列—— pytorch 模型转 onnx

代码

Torch -> onnx 动态batch单输入多(两)输出 的代码如下:

import torch


def export():
    # load model
    model = CartoonPornModel()
    model_weights = '/export/xxxx.tar'
    model_state_dict = torch.load(model_weights, map_location='cpu')
    model.load_state_dict( model_state_dict["state_dict”])
    model.eval()

    # export pytorch to onnx
    dummy_input_1 = torch.randn(1, 3, 320, 320)
    input_names = ["images"]
    output_names = ["probs", "similarity"]
    torch.onnx.export(model, dummy_input_1, "xxxx_ori.onnx", verbose=True, opset_version=12,
                      input_names=input_names, output_names=output_names, 
                      dynamic_axes={"images": [0], "probs": [0], "similarity": [0]})

    # simplify onnx
    import onnxsim, onnx
    model_onnx = onnx.load("xxxx_ori.onnx")
    model_onnx, check = onnxsim.simplify(model_onnx)
    onnx.save(model_onnx, "xxxx_sim.onnx")

if __name__ == '__main__':
    export()

Note :

1、正如pytorch 官网所说:If model is not a torch.jit.ScriptModule nor a torch.jit.ScriptFunction, this runs model once in order to convert it to a TorchScript graph to be exported (the equivalent of torch.jit.trace()). 也就是说对于算法同学常用的将pytorch 原生训练(或推理)代码里面的model,直接调用export 的时候,其实export 会再运行一次模型,也就是对应着我们的forward 函数,所以:

1)要注意不要将我们的推理代码改为’detect’或其他名字;

2)我们想要导出的所有操作都写到forward里面(包含pytorch nn 和 function)。

2、关于input_names 和 output_names 参数,先看官网:

* input_names (list of str, default empty list) – names to assign to the input nodes of the graph, in order.

* output_names (list of str, default empty list) – names to assign to the output nodes of the graph, in order.

1)该参数可省略不写

2)一定要认识到,这两个值是赋给导出的onnx的,并不是让你去找pytorch里面定义的模型的输入、输出名。

3)建议写,因为后面使用TRT的时候还要指定输入输出那个时候指的就是这个输出输出名,然后注意如果是多个输入输出,这里是按照顺序赋值,所以还是要打印出来结果或者去源码forward里面看一下顺序。

3、动态shape,包含动态batch,图像尺寸动态,先看官网:

By default the exported model will have the shapes of all input and output tensors set to exactly match those given in args. To specify axes of tensors as dynamic (i.e. known only at run-time), set dynamic_axes to a dict with schema。并且官网也给了一个很好的例子,可以去参考:torch.onnx — PyTorch 1.12 documentation

主要是对dynamic_axes参数进行赋值(如果不写该参数,默认是定batch,此时batch 的大小取决于导出onnx时的输入的大小),这个参数是一个字典,键是2中提到的输入输出层的名字,值是一个列表或者字典,如果是列表,数字代表哪个维度为动态(比如图像常见的BCHW,动态batch 的场景,需要填入0,从0开始计数);如果为字典,字典的键位维度,值为可以为这个维度起一个名字。

4、在导出ONNX模型之前,必须调用model.eval() 来将dropout和batch normalization层设置为推理模式。并且建议在CPU操作,避免有些操作再GPU不支持。

    

参考链接

详细pytorch export: pytorch模型导出成ONNX格式:支持多参数与动态输入_superbin的博客-CSDN博客_onnx动态输入

官方文档:torch.onnx — PyTorch 1.12 documentation

节点名称转换等:导出ONNX模型 - FrameworkPTAdapter 2.0.1 PyTorch网络模型移植&训练指南 01 - 华为

模型部署入门教程:ONNX 模型的修改与调试

你可能感兴趣的:(深度学习工具链,pytorch,python,深度学习,计算机视觉)