4.1 构建onnx结构模型-Reshape

前言

构建 ONNX 模型通常有两种方式：
1、通过框架代码导出为 ONNX 结构，例如 PyTorch → ONNX；
2、使用 onnx 库自定义节点（node）和计算图（graph），直接生成 ONNX 结构。

本文对这两种构建 ONNX 结构的方式做简单的学习和使用，
下面以 Reshape 节点为例进行分析。

方式

方法一:pytorch --> onnx

固定shape
import torch
 
class JustReshape(torch.nn.Module):
    """Toy module whose only operation is a Reshape, used to demo ONNX export.

    For a 4-D input of shape (d0, d1, d2, d3) the output holds the same
    elements reshaped (not transposed) to (d3, d1, d0, d2).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        dims = x.shape
        # reshape only reinterprets the flat element order; axes are not permuted
        new_shape = (dims[3], dims[1], dims[0], dims[2])
        return x.reshape(*new_shape)
 
# Instantiate the toy model and a fixed-shape example input used for tracing.
net = JustReshape()
dummy_input = torch.randn(1, 31, 42, 5)
model_name = 'just_reshape.onnx'  # file name of the exported ONNX model

# Trace net with dummy_input and write the ONNX graph to model_name;
# the graph's input/output tensors are named 'input' and 'output'.
torch.onnx.export(
    net,
    dummy_input,
    model_name,
    input_names=['input'],
    output_names=['output'],
)

结果如图所示:
[图 1：固定 shape 导出得到的 just_reshape.onnx 计算图]

动态shape

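将第一维度（batch 维，索引 0）设置为动态 shape。注意：该模型的输出形状为 (x.shape[3], x.shape[1], x.shape[0], x.shape[2])，因此输入的 batch 维在输出中位于索引 2。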

# Only the export call changes: declare the batch axis as dynamic.
# JustReshape returns shape (d3, d1, d0, d2), so input axis 0 (batch) ends up
# at output axis 2. Output axis 0 is the static d3 axis and must not share
# the 'batch_size' symbol.
torch.onnx.export(net, dummy_input, model_name,
                  input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {2: 'batch_size'}})

# 可以将得到的模型进一步简化处理
onnxsim 方式：例如执行 `onnxsim just_reshape.onnx reshape_dynamic_sim.onnx`，得到后文运行时加载的 reshape_dynamic_sim.onnx。

方法二：使用 onnx 的 helper API 直接构建

import onnx 
from onnx import TensorProto, helper, numpy_helper

def run():
    """Build a single-node ONNX model reshaping (1, 31, 42, 5) -> (5, 31, 1, 42).

    The target shape is stored as an INT64 initializer named "shape" that
    feeds the second input of the Reshape node; the model imports opset 11.

    Returns:
        The constructed (not yet saved) ``onnx.ModelProto``.
    """
    print("run start....\n")

    src_shape = [1, 31, 42, 5]
    dst_shape = [5, 31, 1, 42]

    # Reshape consumes the data tensor plus a 1-D int64 tensor holding the new shape.
    shape_const = helper.make_tensor("shape", TensorProto.INT64, [len(dst_shape)], dst_shape)
    reshape_node = helper.make_node(
        "Reshape",
        inputs=["input", "shape"],
        outputs=["output"],
        name="Reshape_0",
    )

    graph_input = helper.make_tensor_value_info("input", TensorProto.FLOAT, src_shape)
    graph_output = helper.make_tensor_value_info("output", TensorProto.FLOAT, dst_shape)
    graph = helper.make_graph(
        nodes=[reshape_node],
        name="test_graph",
        inputs=[graph_input],
        outputs=[graph_output],
        initializer=[shape_const],
    )

    opset = onnx.OperatorSetIdProto(version=11)
    model = helper.make_model(graph, opset_imports=[opset])
    print("run done....\n")
    return model

if __name__ == "__main__":
    # Build the Reshape model in memory, then serialize it to disk.
    out_path = "./test_reshape.onnx"
    model = run()
    onnx.save(model, out_path)

运行 ONNX 模型（使用 onnxruntime 推理）

import onnx
import onnxruntime
import numpy as np


# Validate the ONNX computation graph.
def check_onnx(model, verbose=False):
    """Run the ONNX checker on *model*.

    The original signature misspelled the parameter as ``mdoel`` and the body
    silently validated the module-level ``model`` instead. It now checks the
    argument it is given.

    Args:
        model: an in-memory ``onnx.ModelProto`` (e.g. from ``onnx.load``).
        verbose: if True, also print a human-readable dump of the graph.

    Raises:
        Whatever ``onnx.checker.check_model`` raises for an invalid model
        (``onnx.checker.ValidationError``).
    """
    onnx.checker.check_model(model)
    if verbose:
        print(onnx.helper.printable_graph(model.graph))

def run(model, input_shape=(24, 31, 42, 5)):
    """Run one onnxruntime inference on random input and print the shapes.

    Args:
        model: path to an ``.onnx`` file (or serialized model bytes), as
            accepted by ``onnxruntime.InferenceSession``.
        input_shape: shape of the random float32 tensor fed to the first graph
            input. The default batch of 24 differs from the export-time batch
            of 1, which exercises the dynamic batch axis.

    Returns:
        The first graph output as a numpy array.
    """
    print('run start....\n')
    session = onnxruntime.InferenceSession(model, providers=['CPUExecutionProvider'])
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    input_data = np.random.randn(*input_shape).astype(np.float32)
    print(f'input_data1 shape:{input_data.shape}\n')

    pred_onx = session.run([output_name], {input_name: input_data})[0]
    print(f'pred_onx shape:{pred_onx.shape} \n')

    print('run end....\n')
    return pred_onx


if __name__ == '__main__':
    # Presumably the onnxsim-simplified dynamic-batch export; verify the file exists.
    path = "./reshape_dynamic_sim.onnx"
    # Load once from the same path that is handed to onnxruntime, so the
    # checked model and the executed model cannot diverge.
    model = onnx.load(path)
    check_onnx(model)
    run(path)
    

你可能感兴趣的:(模型推理,性能优化,onnx)