Parsing ONNX (C++ and Python versions)

        There are many Python implementations of ONNX parsers available online, but so far no open-source C++ version, so this post publishes C++ code for parsing ONNX data.

        First, download onnx.proto from the onnx/onnx repository on GitHub (onnx/onnx at main · onnx/onnx · GitHub), then install protobuf (any Protobuf installation tutorial will do) and the protoc compiler.

        Next, onnx.proto must be compiled with the protoc compiler.

       

C++ version

         I used protoc version 3.21.2; on the Windows command line (cmd), the compile command is as follows:
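
        Assuming protoc.exe is on the PATH and onnx.proto is in the current directory, the invocation is typically:

protoc --cpp_out=. onnx.proto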

      

        Compilation produces two files, onnx.pb.h and onnx.pb.cc. In a new VS project, include the onnx.pb.h header to use the generated functions; onnx.pb.cc must also be added to the project (and the protobuf runtime library linked) for it to build. The code below parses an ONNX model in detail:

#include <cassert>
#include <cstring>
#include <fstream>
#include <iostream>
#include <vector>
#include "onnx.pb.h"


// Print a single dimension: either a symbolic name (dim_param) or a concrete value (dim_value).
void print_dim(const ::onnx::TensorShapeProto_Dimension &dim)
{
	switch (dim.value_case())
	{
	case onnx::TensorShapeProto_Dimension::ValueCase::kDimParam:
		std::cout << dim.dim_param();
		break;
	case onnx::TensorShapeProto_Dimension::ValueCase::kDimValue:
		std::cout << dim.dim_value();
		break;
	default:
		assert(false && "should never happen");
	}
}
// Print the name and shape of each graph input/output.
void print_io_info(const ::google::protobuf::RepeatedPtrField<::onnx::ValueInfoProto> &info)
{
	for (const auto &input_data : info)
	{
		auto shape = input_data.type().tensor_type().shape();
		std::cout << "  " << input_data.name() << ":";
		std::cout << "[";
		if (shape.dim_size() != 0)
		{
			int size = shape.dim_size();
			for (int i = 0; i < size - 1; ++i)
			{
				print_dim(shape.dim(i));
				std::cout << ",";
			}
			print_dim(shape.dim(size - 1));
		}
		std::cout << "]\n";
	}
}

// Reinterpret four little-endian bytes as an IEEE-754 float
// (assumes the host is also little-endian, as x86/x64 is).
float from_le_bytes(const unsigned char* bytes)
{
	float value;
	std::memcpy(&value, bytes, sizeof(float));
	return value;
}
// Print each initializer (weight tensor): its name, data type, shape, and data.
// This assumes the weights are float32 serialized into raw_data; some models
// store them in the typed float_data field instead (see the summary below).
void print_initializer_info(const ::google::protobuf::RepeatedPtrField<::onnx::TensorProto>& info)
{
	for (const auto &input_data : info)
	{
		std::cout << "  " << input_data.name() << ":";
		std::cout << " data_type: " << input_data.data_type() << std::endl;
		std::cout << "shapes: ";
		for (auto dim : input_data.dims())
			std::cout << dim << " ";
		std::cout << std::endl;
		const std::string &raw_data = input_data.raw_data(); // serialized weights
		const float *data_r = reinterpret_cast<const float*>(raw_data.data()); // read raw_data
		int k = static_cast<int>(raw_data.size() / 4); // a float is 4 bytes
		for (int i = 0; i < k; ++i)
			std::cout << data_r[i] << " "; // print weights
		std::cout << std::endl;
		std::cout << raw_data.size() << " bytes" << std::endl;
	}
}

// Print each node: op type, name, inputs, outputs, and its attributes.
// Only integer-list attributes (the AttributeProto `ints` field) are printed.
void print_node_info(const ::google::protobuf::RepeatedPtrField<::onnx::NodeProto>& info)
{
	for (const auto &node : info)
	{
		std::cout << node.op_type() << " " << node.name() << ":";
		std::cout << std::endl << "Inputs:";
		for (const auto &inp : node.input())
			std::cout << inp << " ";
		std::cout << std::endl << "Outputs:";
		for (const auto &outp : node.output())
			std::cout << outp << " ";
		std::cout << std::endl << "[";
		for (const auto &attr : node.attribute()) // AttributeProto
		{
			std::cout << attr.name() << ": ";
			for (auto t : attr.ints())
				std::cout << t << " ";
		}
		std::cout << "]\n";
	}
}

int main(void)
{

	// Parse the model message directly from a stream.
	onnx::ModelProto out_msg;
	{
		std::fstream input("person.onnx", std::ios::in | std::ios::binary);
		if (!out_msg.ParseFromIstream(&input)) {
		  std::cerr << "failed to parse" << std::endl;
		  return -1;
		}
		std::cout << out_msg.graph().node_size() << std::endl;
	}
	// Alternatively, read the file into a buffer and parse from memory.
	onnx::ModelProto model;
	std::ifstream input("person.onnx", std::ios::ate | std::ios::binary);
	// get current position in file
	std::streamsize size = input.tellg();
	// move to start of file
	input.seekg(0, std::ios::beg);
	// read raw data
	std::vector<char> buffer(size);
	input.read(buffer.data(), size); 
	model.ParseFromArray(buffer.data(), size); // parse protobuf
	auto graph = model.graph();
	std::cout << graph.initializer_size() << std::endl;
	std::cout << "graph inputs:\n";
	print_io_info(graph.input());
	std::cout << "graph outputs:\n";
	print_io_info(graph.output());
	std::cout << "graph initializer:\n";
	print_initializer_info(graph.initializer());
	std::cout << "graph node:\n";
	print_node_info(graph.node());
	
	return 0;
}

Python version:

        The proto file is likewise compiled with the protoc compiler; the command is as follows:
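
        Assuming again that protoc is on the PATH and onnx.proto is in the current directory, the invocation is typically:

protoc --python_out=. onnx.proto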

        Compilation generates onnx_pb2.py, which plays the same role as the .h/.cc pair in C++; it can be imported directly and used for parsing, as in the code below:

# -*- coding: utf-8 -*-
# @Time : 2022-07-14 16:18
# @Author : ZhangTong
# @Email : [email protected]
# @File : onnx_test_reading.py

from onnx_new import onnx_pb2  # onnx_pb2 is the module generated by protoc above
import numpy as np


def onnx_datatype_to_npType(data_type):  # map ONNX TensorProto.DataType to a numpy dtype
    if data_type == 1:      # FLOAT
        return np.float32
    elif data_type == 2:    # UINT8
        return np.uint8
    elif data_type == 3:    # INT8
        return np.int8
    elif data_type == 4:    # UINT16
        return np.uint16
    elif data_type == 5:    # INT16
        return np.int16
    elif data_type == 6:    # INT32
        return np.int32
    elif data_type == 7:    # INT64
        return np.int64
    elif data_type == 9:    # BOOL
        return np.bool_
    else:                   # 8 is STRING, which has no fixed-size dtype; default to float32
        return np.float32



def read_test():
    model = onnx_pb2.ModelProto()
    onnx_file = "./model_onnx/Demo_144k_CNN.onnx"
    #onnx_file = "./tiny-yolov3-11.onnx"

    try:
        f = open(onnx_file, "rb")
        model.ParseFromString(f.read())
        f1 = open("./model_onnx/person.onnx", 'wb')  # write out a re-serialized copy
        f1.write(model.SerializeToString())
        f1.close()
        graph = model.graph
        nodes = graph.node
        inputs = graph.input
        outputs = graph.output
        initializer = model.graph.initializer
        # print(model.graph.node)
        print('initializer:', len(initializer))
        print('nodes:', len(nodes))
        for i in range(len(initializer)):
            print(initializer[i].name, initializer[i].data_type)
            # if initializer[i].name.split('_')[-1] == 'quantized':
            if True:
                raw_data = np.frombuffer(initializer[i].raw_data,
                                         dtype=onnx_datatype_to_npType(initializer[i].data_type))
                print('raw_data:', raw_data)
            # int32_data and float_data are repeated protobuf fields, not byte
            # strings, so convert them with np.array rather than np.frombuffer
            print('int32_data:', np.array(initializer[i].int32_data, dtype=np.int32))
            print('float_data:', np.array(initializer[i].float_data, dtype=np.float32))
        # dump the node list to a text file for inspection
        data = open("mnist-12-int8.node.txt", 'w', encoding="utf-8")
        print(model.graph.node, file=data)
        data.close()
        f.close()
    except IOError:
        print(onnx_file + ": could not open file.")

if __name__ == "__main__":
    read_test()

        To summarize: from the ONNX models I have analyzed, most TensorProtos are stored in the file as initializers. Some weights live in the typed float_data field, while others are serialized as raw_data, so it is best to check both forms when parsing; within a single ONNX model the storage form is generally consistent. That wraps up the C++ and Python code for parsing ONNX. Likes are appreciated!
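
        As a concrete sketch of that check (reusing onnx_datatype_to_npType from above; the helper name initializer_to_array is my own), weights can be pulled out of either storage form like this:

def initializer_to_array(init):
    # weights serialized as raw little-endian bytes
    if len(init.raw_data) > 0:
        return np.frombuffer(init.raw_data, dtype=onnx_datatype_to_npType(init.data_type))
    # weights stored in the typed repeated fields instead
    if len(init.float_data) > 0:
        return np.array(init.float_data, dtype=np.float32)
    if len(init.int32_data) > 0:
        return np.array(init.int32_data, dtype=np.int32)
    return np.array([])  # no recognized payload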
