ONNX网络模型解析

        ONNX是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不同的人工智能框架(如Pytorch, MXNet)可以采用相同格式存储模型数据并交互。 ONNX的规范及代码主要由微软,亚马逊 ,Facebook 和 IBM 等公司共同开发,以开放源代码的方式托管在Github上。目前官方支持加载ONNX模型并进行推理的深度学习框架有: Caffe2, PyTorch, MXNet,ML.NET,TensorRT 和 Microsoft CNTK,并且 TensorFlow 也非官方的支持ONNX。(参考维基百科)

        如需自定义网络模型结构,尝试使用自定义的结构保存网络模型,首先需要解析ONNX的net结构。ONNX使用google定义的报文格式protocol buffer,用于RPC 系统和持续数据存储系统。

        Protocol Buffers 是一种轻便高效的结构化数据存储格式,可以用于结构化数据串行化,或者说序列化。它很适合做数据存储或 RPC 数据交换格式。可用于通讯协议、数据存储等领域的语言无关、平台无关、可扩展的序列化结构数据格式。目前提供了 C++、Java、Python 三种语言的 API,本文使用了C++接口解析ONNX模型。

        Protocol buffer结构通常会定义要给proto文件,通过ONNXgithub链接下载onnx.protohttps://github.com/onnx/onnx。解析模型用到的结构主要如下:

  1. ModelProto:最高级别的结构,定义了整个网络模型结构;
  2. GraphProto: graph定义了模型的计算逻辑以及带有参数的node节点,组成一个有向图结构;
  3. NodeProto: 网络有向图的各个节点OP的结构,通常称为层,例如conv,relu层;
  4. AttributeProto:各OP的参数,通过该结构访问,例如:conv层的stride,dilation等;
  5. TensorProto: 序列化的tensor value,一般weight,bias等常量均保存为该种结构;
  6. TensorShapeProto:网络的输入shape以及constant输入tensor的维度信息均保存为该种结构;
  7. TypeProto:表示ONNX数据类型。

        具体解析流程是读取.onnx文件,获得一个model结构,通过model结构访问到graph结构,然后通过graph访问整个网络的所有node以及input,output,通过node结构可以访问到OP的参数。

下面给出解析demo:

void ReadProtoFromBinaryFile(const char* filename, google::protobuf::Message* proto)
 {
	int fd = open(filename, O_RDONLY);
	google::protobuf::io::FileInputStream* raw_input = new google::protobuf::io::FileInputStream(fd);
	google::protobuf::io::CodedInputStream* coded_input = new google::protobuf::io::CodedInputStream(raw_input);
	coded_input->SetTotalBytesLimit(INT_MAX, 536870912);
	bool success = proto->ParseFromCodedStream(coded_input);
	//bool success = proto->ParseFromZeroCopyStream(raw_input);
	delete coded_input;
	delete raw_input;
	close(fd);
	if (success != true)
	{
		exit(1);
	}
}
int main()
{
	char* ch = "inception_v3.onnx";
	onnx::ModelProto model_data;
	ReadProtoFromBinaryFile(ch, &model_data);    //读取文件保存为modelproto结构
	onnx::GraphProto graph = model_data.graph();  //访问graph结构
	int num = graph.node_size();                    //node节点个数
	int input_size = graph.input_size();            //网络输入个数,input以及各层常量输入
	for (int i = 0; i < input_size; ++i)
	{
		const std::string name = graph.input(i).name();
		onnx::TypeProto type = graph.input(i).type();
		onnx::TensorShapeProto shape = type.tensor_type().shape();//输入维度
		for (int i = 0; i < shape.dim_size(); i++)
		{
			std::cout << shape.dim(i).dim_value();
			std::cout<< std::endl;
		}
	}
	int output_size = graph.output_size();    //网络output个数
	for (int i = 0; i < num; i++)    //遍历每个node结构
	{
		const onnx::NodeProto node = graph.node(i);
		std::string node_name = node.name();
		std::cout <<"cur node name:"<< node_name << std::endl;
		const ::google::protobuf::RepeatedPtrField< ::onnx::AttributeProto> attr = node.attribute();        //每个node结构的参数信息
		const std::string type = node.op_type();
		int in_size = node.input_size();	
		int out_size = node.output_size();
	}
	return 0;
}

 

你可能感兴趣的:(深度学习)