解析ONNX模型时,网上能找到的大多是Python版本的解析代码,C++版本的示例则很少见,因此这里公开一版用C++解析ONNX数据的代码(文末附Python版本)。
首先从GitHub仓库onnx/onnx(main分支)下载onnx/onnx.proto文件,然后安装Protobuf库和protoc编译器(可参考Protobuf官方安装教程)。
接下来用protoc编译器编译onnx.proto。我使用的版本为3.21.2(注意protoc的版本需要与工程中链接的protobuf库版本一致,否则生成的代码无法编译),在cmd命令行中执行:
protoc --cpp_out=. onnx.proto
编译完成后会生成onnx.pb.h和onnx.pb.cc两个文件。在新建的VS工程中用#include包含onnx.pb.h,把onnx.pb.cc加入工程一起编译,并链接protobuf库,即可使用其中生成的类和函数。下面是解析ONNX的完整代码:
#include <cassert>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include "onnx.pb.h"
/// Prints a single tensor-shape dimension to stdout.
///
/// A dimension is either a concrete size (dim_value), a symbolic name
/// (dim_param, e.g. "batch"), or left unset. Unset dimensions are legal in
/// ONNX (unknown size), so they are printed as "?" rather than asserting.
/// @param dim  one entry of TensorShapeProto::dim().
void print_dim(const ::onnx::TensorShapeProto_Dimension &dim)
{
    switch (dim.value_case())
    {
    case onnx::TensorShapeProto_Dimension::ValueCase::kDimParam:
        std::cout << dim.dim_param();
        break;
    case onnx::TensorShapeProto_Dimension::ValueCase::kDimValue:
        std::cout << dim.dim_value();
        break;
    default:
        // VALUE_NOT_SET: the model leaves this dimension unspecified.
        std::cout << '?';
        break;
    }
}
/// Prints the name and tensor shape of each graph input/output, one per line,
/// in the form " name:[d0,d1,...]".
///
/// Value infos without a tensor shape print as "[]", because the protobuf
/// accessors return empty default messages for unset fields.
/// @param info  repeated ValueInfoProto field, e.g. GraphProto::input().
void print_io_info(const ::google::protobuf::RepeatedPtrField<::onnx::ValueInfoProto> &info)
{
    // Bind by const reference: copying each ValueInfoProto (and its shape)
    // on every iteration is pure overhead.
    for (const auto &value_info : info)
    {
        const auto &shape = value_info.type().tensor_type().shape();
        std::cout << " " << value_info.name() << ":[";
        for (int i = 0; i < shape.dim_size(); ++i)
        {
            if (i != 0)
                std::cout << ",";
            print_dim(shape.dim(i));
        }
        std::cout << "]\n";
    }
}
/// Decodes a 32-bit IEEE-754 float stored in little-endian byte order.
///
/// @param bytes  pointer to at least 4 readable bytes, least significant first.
/// @return the decoded float, independent of the host's byte order.
float from_le_bytes(const unsigned char* bytes)
{
    // Build the bit pattern explicitly so host endianness does not matter,
    // then copy it into a float (memcpy avoids strict-aliasing UB).
    const std::uint32_t bits = static_cast<std::uint32_t>(bytes[0])
                             | (static_cast<std::uint32_t>(bytes[1]) << 8)
                             | (static_cast<std::uint32_t>(bytes[2]) << 16)
                             | (static_cast<std::uint32_t>(bytes[3]) << 24);
    float value;
    static_assert(sizeof(value) == sizeof(bits), "float must be 32 bits wide");
    std::memcpy(&value, &bits, sizeof(value));
    return value;
}
/// Prints every graph initializer (constant tensor, typically a weight).
///
/// For each initializer it prints the name, the ONNX data_type code, the
/// shape and the size of raw_data. For FLOAT tensors it also prints the
/// element values.
///
/// Float values come from raw_data when it is non-empty, and from the typed
/// float_data field otherwise; models use either encoding.
/// @param info  repeated TensorProto field, e.g. GraphProto::initializer().
void print_initializer_info(const ::google::protobuf::RepeatedPtrField<::onnx::TensorProto>& info)
{
    // Bind by const reference: initializers carry the model weights, so
    // copying each TensorProto (and its raw_data string) is expensive.
    for (const auto &tensor : info)
    {
        std::cout << " " << tensor.name() << " (data_type " << tensor.data_type() << ")\n";
        std::cout << "shapes: ";
        for (const auto dim : tensor.dims())
            std::cout << dim << " ";
        std::cout << "\n";

        const std::string &raw = tensor.raw_data();
        std::cout << "raw_data bytes: " << raw.size() << "\n";

        if (tensor.data_type() != onnx::TensorProto::FLOAT)
        {
            // Other element types need type-specific decoding; only their
            // byte size is reported.
            continue;
        }

        if (!raw.empty())
        {
            // raw_data is a byte string with no alignment guarantee, so each
            // element is memcpy'd out instead of casting c_str() to float*.
            // NOTE(review): bytes are read in host order; this matches the
            // little-endian raw_data layout only on little-endian hosts.
            const std::size_t count = raw.size() / sizeof(float);
            for (std::size_t i = 0; i < count; ++i)
            {
                float value;
                std::memcpy(&value, raw.data() + i * sizeof(float), sizeof(float));
                std::cout << value << " ";
            }
        }
        else
        {
            for (const float value : tensor.float_data())
                std::cout << value << " ";
        }
        std::cout << "\n";
    }
}
/// Prints each graph node: op type and name, its input and output tensor
/// names, and its attributes as "[name: values ...]".
///
/// Scalar (INT/FLOAT/STRING) and FLOATS attributes print their value(s).
/// Every other attribute type, including ones with no type set, prints its
/// ints list, so INTS attributes print as before.
/// @param info  repeated NodeProto field, e.g. GraphProto::node().
void print_node_info(const ::google::protobuf::RepeatedPtrField<::onnx::NodeProto>& info)
{
    // Const references throughout: nodes, attribute lists and names would
    // otherwise be copied on every iteration.
    for (const auto &node : info)
    {
        std::cout << node.op_type() << " " << node.name() << ":";
        std::cout << "\nInputs:";
        for (const auto &name : node.input())
            std::cout << name << " ";
        std::cout << "\nOutputs:";
        for (const auto &name : node.output())
            std::cout << name << " ";
        std::cout << "\n[";
        for (const auto &attr : node.attribute())
        {
            std::cout << attr.name() << ": ";
            switch (attr.type())
            {
            case onnx::AttributeProto::INT:
                std::cout << attr.i() << " ";
                break;
            case onnx::AttributeProto::FLOAT:
                std::cout << attr.f() << " ";
                break;
            case onnx::AttributeProto::STRING:
                std::cout << attr.s() << " ";
                break;
            case onnx::AttributeProto::FLOATS:
                for (const float value : attr.floats())
                    std::cout << value << " ";
                break;
            default:
                for (const auto value : attr.ints())
                    std::cout << value << " ";
                break;
            }
        }
        std::cout << "]\n";
    }
}
/// Loads an ONNX model and prints its node count, initializer count,
/// inputs, outputs, initializers and nodes.
///
/// Usage: program [model.onnx]   (the path defaults to "person.onnx")
/// @return 0 on success, -1 if the file cannot be opened or parsed.
int main(int argc, char* argv[])
{
    const std::string path = argc > 1 ? argv[1] : "person.onnx";

    std::ifstream input(path, std::ios::in | std::ios::binary);
    if (!input)
    {
        std::cerr << "failed to open " << path << std::endl;
        return -1;
    }

    // Parse once, directly from the stream, and check the result.
    onnx::ModelProto model;
    if (!model.ParseFromIstream(&input))
    {
        std::cerr << "failed to parse" << std::endl;
        return -1;
    }

    // Bind by reference: copying the GraphProto would duplicate all weights.
    const auto &graph = model.graph();
    std::cout << graph.node_size() << std::endl;
    std::cout << graph.initializer_size() << std::endl;
    std::cout << "graph inputs:\n";
    print_io_info(graph.input());
    std::cout << "graph outputs:\n";
    print_io_info(graph.output());
    std::cout << "graph initializer:\n";
    print_initializer_info(graph.initializer());
    std::cout << "graph node:\n";
    print_node_info(graph.node());
    return 0;
}
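Python版本同样先使用protoc编译器编译onnx.proto,编译命令如下:
protoc --python_out=. onnx.proto
编译完成之后会生成onnx_pb2.py,它的作用类似于C++中的.h头文件:在代码中导入该模块即可直接解析ONNX(下面的代码把它放在onnx_new包中,通过from onnx_new import onnx_pb2导入)。代码如下: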
# -*- coding: utf-8 -*-
# @Time : 2022-07-14 16:18
# @Author : ZhangTong
# @Email : [email protected]
# @File : onnx_test_reading.py
from onnx_new import onnx_pb2
import numpy as np
# ONNX TensorProto.DataType code -> numpy dtype (codes as defined in onnx.proto).
_ONNX_TO_NP_DTYPE = {
    1: np.float32,      # FLOAT
    2: np.uint8,        # UINT8
    3: np.int8,         # INT8
    4: np.uint16,       # UINT16
    5: np.int16,        # INT16
    6: np.int32,        # INT32
    7: np.int64,        # INT64
    8: np.uint8,        # STRING: kept as uint8 so frombuffer() on an (empty) raw_data still works
    9: np.bool_,        # BOOL (np.bool8 was removed in NumPy 2.0)
    10: np.float16,     # FLOAT16
    11: np.float64,     # DOUBLE
    12: np.uint32,      # UINT32
    13: np.uint64,      # UINT64
    14: np.complex64,   # COMPLEX64
    15: np.complex128,  # COMPLEX128
}


def onnx_datatype_to_npType(data_type):
    """Map an ONNX TensorProto ``data_type`` code to the matching numpy dtype.

    Args:
        data_type: integer TensorProto.DataType code (e.g. 1 for FLOAT).

    Returns:
        The numpy scalar type for that code. Codes without a mapping
        (e.g. 0 = UNDEFINED, or BFLOAT16, which numpy has no native type for)
        fall back to ``np.float32``.
    """
    return _ONNX_TO_NP_DTYPE.get(data_type, np.float32)
def read_test(onnx_file="./model_onnx/Demo_144k_CNN.onnx"):
    """Parse an ONNX model with the generated onnx_pb2 module and dump its contents.

    Side effects:
      * re-serializes the model to ./model_onnx/person.onnx;
      * prints the initializer/node counts and, for each initializer, its
        name, data_type and its raw_data / int32_data / float_data contents;
      * writes the textual node list to mnist-12-int8.node.txt.

    Args:
        onnx_file: path of the model to read (e.g. "./tiny-yolov3-11.onnx").
    """
    model = onnx_pb2.ModelProto()
    try:
        # `with` guarantees every file is closed even if parsing/writing fails.
        with open(onnx_file, "rb") as f:
            model.ParseFromString(f.read())
        with open("./model_onnx/person.onnx", 'wb') as f1:
            f1.write(model.SerializeToString())

        graph = model.graph
        initializer = graph.initializer
        print('initializer:', len(initializer))
        print('nodes:', len(graph.node))
        for tensor in initializer:
            print(tensor.name, tensor.data_type)
            # Weights may be serialized in raw_data or in typed repeated
            # fields, so show both. Repeated fields do not expose the buffer
            # protocol, hence np.array(list(...)) instead of np.frombuffer().
            print(np.frombuffer(tensor.raw_data,
                                dtype=onnx_datatype_to_npType(tensor.data_type)))
            print('int32_data:', np.array(list(tensor.int32_data), dtype=np.int32))
            print('float_data:', np.array(list(tensor.float_data), dtype=np.float32))

        with open("mnist-12-int8.node.txt", 'w', encoding="utf-8") as data:
            print(model.graph.node, file=data)
    except IOError:
        print(onnx_file + ": Could not open file.")
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    read_test()
总结一下:分析过一些ONNX模型后发现,大部分TensorProto都以initializer的形式存储在ONNX中。其中有些权重存放在float_data等按类型划分的字段中,有些则以raw_data(小端字节序)的形式序列化存储,所以解析时最好两种都检查一下。一般来说,同一个ONNX模型内的存储形式是一致的。至此,C++版本和Python版本解析ONNX的代码都已更新完毕,多赞呀!!!!