4.9 构建onnx结构模型-Equal

前言

构建onnx方式通常有两种:
1、通过代码转换成onnx结构,比如pytorch —> onnx
2、通过onnx 自定义结点,图,生成onnx结构

本文主要是简单学习和使用这两种构建 onnx 结构的方式,
下面以 Equal 结点为例进行分析。
4.9 构建onnx结构模型-Equal_第1张图片

方式

方法一:pytorch --> onnx

暂缓,本文主要研究方法二。

方法二: onnx

# import  torch
# import torch.nn as nn

# class JustEqual(nn.Module):
#     def __init__(self):
#         super(JustEqual, self).__init__()

#     def forward(self,x):


#         return x

import onnx 
from onnx import TensorProto, helper, numpy_helper
import numpy as np

def run():
    """Assemble a two-node ONNX model: ``Equal`` followed by ``Cast``.

    The ``Equal`` node compares the float32 graph input ``input``
    (declared shape [16, 1, 397]) with a one-element float32 zero
    initializer named ``equal`` and writes ``output1``.  ``output1`` is
    declared as a BOOL tensor in ``value_info`` and is then cast to
    float32 to produce the graph output ``output2``.

    Returns:
        The model produced by ``helper.make_model`` with a single opset
        import of version 11 and ``ir_version`` set to 8.
    """
    print("run start....\n")

    tensor_shape = [16, 1, 397]

    # Constant operand of the comparison: a single float32 zero.
    # Other variants that were tried here: a [1, 1, 1] zero tensor, and a
    # scalar declared with dims=[] and vals=[0].
    zero_operand = helper.make_tensor(
        "equal", TensorProto.FLOAT, [1], np.zeros(1, dtype=np.float32)
    )

    equal_node = helper.make_node(
        "Equal",
        inputs=["input", "equal"],
        outputs=["output1"],
        name="Equal_0",
    )
    # Equal's result is declared BOOL below; Cast converts it to FLOAT.
    cast_node = helper.make_node(
        "Cast",
        inputs=["output1"],
        outputs=["output2"],
        name="test_cast",
        to=TensorProto.FLOAT,
    )

    graph_input = helper.make_tensor_value_info(
        "input", TensorProto.FLOAT, tensor_shape
    )
    graph_output = helper.make_tensor_value_info(
        "output2", TensorProto.FLOAT, tensor_shape
    )
    # Type/shape annotation for the intermediate tensor between the nodes.
    intermediate_info = helper.make_tensor_value_info(
        "output1", TensorProto.BOOL, tensor_shape
    )

    graph = helper.make_graph(
        nodes=[equal_node, cast_node],
        name="test_graph",
        inputs=[graph_input],
        outputs=[graph_output],
        initializer=[zero_operand],
        value_info=[intermediate_info],
    )

    opset = onnx.OperatorSetIdProto()
    opset.version = 11
    model = helper.make_model(graph, opset_imports=[opset])
    # Override whatever IR version make_model picked.
    model.ir_version = 8
    print("run done....\n")
    return model

if __name__ == "__main__":
    # Build the model in memory, then serialize it next to the script.
    model = run()
    onnx_path = "./test_equal.onnx"
    onnx.save(model, onnx_path)


运行:使用 onnxruntime 加载并推理上面生成的 test_equal.onnx

import onnx
import onnxruntime
import numpy as np


# Validate the ONNX computation graph
def check_onnx(mdoel):
    """Validate an ONNX model with ``onnx.checker.check_model``.

    Args:
        mdoel: The model to validate.  The parameter name keeps its
            original (misspelled) form so the signature is unchanged.
    """
    # Check the argument that was passed in.  The previous body read a
    # global named ``model`` instead, so the parameter was ignored and the
    # call failed with NameError when no such global existed.
    onnx.checker.check_model(mdoel)
    # To inspect the graph: print(onnx.helper.printable_graph(mdoel.graph))

def run(model):
    """Run a single CPU inference pass and print the tensor shapes.

    Args:
        model: Forwarded unchanged to ``onnxruntime.InferenceSession``;
            the caller in this script passes the path of the .onnx file.
    """
    print(f'run start....\n')
    sess = onnxruntime.InferenceSession(
        model, providers=['CPUExecutionProvider']
    )

    # Feed the first declared graph input with random float32 data.
    first_input = session_input = sess.get_inputs()[0]
    input_data1 = np.random.randn(16, 1, 397).astype(np.float32)
    print(f'input_data1 shape:{input_data1.shape}\n')
    feed = {first_input.name: input_data1}

    # Request only the first declared graph output.
    first_output = sess.get_outputs()[0]
    results = sess.run([first_output.name], feed)
    pred_onx = results[0]

    print(f'pred_onx shape:{pred_onx.shape} \n')

    print(f'run end....\n')


if __name__ == '__main__':
    path = "./test_equal.onnx"
    # Load the serialized model once for validation; inference below
    # opens the same file by path.
    model = onnx.load(path)
    check_onnx(model)
    run(path)

你可能感兴趣的:(模型推理,onnx,python,onnxruntime,性能优化)