【onnx】 Parsing the Computation Graph of an ONNX Model


1 Introduction to ONNX

2 Model Parsing

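An ONNX file stores both the model's computation graph and its pretrained parameters. We start by parsing the computation graph to get an overview of the whole computation flow.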

A computation graph is simply a composition of ops: each op has inputs and outputs, and connecting all the ops together forms a graph.

2.1 The ONNX protobuf definitions

The definitions nest as ModelProto -> GraphProto -> NodeProto, and these three messages are the main parts.

The outermost message is ModelProto. It records model-level information (the IR version, the producer such as PyTorch/TensorFlow, …) together with the GraphProto.

message ModelProto {
  // The version of the IR this model targets. See Version enum above.
  // This field MUST be present.
  int64 ir_version = 1;

  // The OperatorSets this model relies on.
  // All ModelProtos MUST have at least one entry that
  // specifies which version of the ONNX OperatorSet is
  // being imported.
  //
  // All nodes in the ModelProto's graph will bind against the operator
  // with the same-domain/same-op_type operator with the HIGHEST version
  // in the referenced operator sets.
  repeated OperatorSetIdProto opset_import = 8;

  // The name of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  string producer_name = 2;

  // The version of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  string producer_version = 3;

  // Domain name of the model.
  // We use reverse domain names as name space indicators. For example:
  // `com.facebook.fair` or `com.microsoft.cognitiveservices`
  //
  // Together with `model_version` and GraphProto.name, this forms the unique identity of
  // the graph.
  string domain = 4;

  // The version of the graph encoded. See Version enum below.
  int64 model_version = 5;

  // A human-readable documentation for this model. Markdown is allowed.
  string doc_string = 6;

  // The parameterized graph that is evaluated to execute the model.
  GraphProto graph = 7;

  // Named metadata values; keys should be distinct.
  repeated StringStringEntryProto metadata_props = 14;
};

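GraphProto is the core. It mainly contains:
1. initializer (TensorProto): stores constant tensors and the pretrained parameters.
2. node (NodeProto): stores each op together with the names of its input and output tensors.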

message GraphProto {
  // The nodes in the graph, sorted topologically.
  repeated NodeProto node = 1;

  // The name of the graph.
  string name = 2;   // namespace Graph

  // A list of named tensor values, used to specify constant inputs of the graph.
  // Each TensorProto entry must have a distinct name (within the list) that
  // also appears in the input list.
  repeated TensorProto initializer = 5;

  // A human-readable documentation for this graph. Markdown is allowed.
  string doc_string = 10;

  // The inputs and outputs of the graph.
  repeated ValueInfoProto input = 11;
  repeated ValueInfoProto output = 12;

  // Information for the values in the graph. The ValueInfoProto.name's
  // must be distinct. It is optional for a value to appear in value_info list.
  repeated ValueInfoProto value_info = 13;

  // DO NOT USE the following fields, they were deprecated from earlier versions.
  // repeated string input = 3;
  // repeated string output = 4;
  // optional int64 ir_version = 6;
  // optional int64 producer_version = 7;
  // optional string producer_tag = 8;
  // optional string domain = 9;
}

NodeProto

message NodeProto {
  repeated string input = 1;    // namespace Value
  repeated string output = 2;   // namespace Value

  // An optional identifier for this node in a graph.
  // This field MAY be absent in this version of the IR.
  string name = 3;     // namespace Node

  // The symbolic identifier of the Operator to execute.
  string op_type = 4;  // namespace Operator
  // The domain of the OperatorSet that specifies the operator named by op_type.
  string domain = 7;   // namespace Domain

  // Additional named attributes.
  repeated AttributeProto attribute = 5;

  // A human-readable documentation for this node. Markdown is allowed.
  string doc_string = 6;
}

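So for every node in the graph, its inputs are listed in node.input and its outputs are recorded in node.output. Some node inputs, however, are not produced by any other node. Constant tensors among them are stored in graph.initializer, and the rest are the graph inputs (graph.input).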

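In summary: every tensor name that some node consumes but no node produces, i.e. set(inputs) - set(outputs), is either in graph.initializer or a graph input. Conversely, set(outputs) - set(inputs) is normally the set of graph outputs (graph.output).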

Python parsing code

import onnx


def onnx_parser(onnx_path):
    onnx_model = onnx.load(onnx_path)
    graph = onnx_model.graph

    # names of constant tensors (pretrained weights and other constants)
    onnx_initial = {init.name for init in graph.initializer}
    # names declared as the graph's inputs / outputs
    graph_inputs = {vi.name for vi in graph.input}
    graph_outputs = {vi.name for vi in graph.output}

    inputs = []
    outputs = []

    for node in graph.node:
        # tensor names consumed by this node; "" marks an omitted optional input
        for name in node.input:
            if name:
                inputs.append(name)

        # tensor names produced by this node
        for name in node.output:
            outputs.append(name)

    print("len inputs: ", len(inputs))
    print("len outputs: ", len(outputs))

    print("len unique inputs: ", len(set(inputs)))
    print("len unique outputs: ", len(set(outputs)))

    # consumed but never produced: must be initializers or graph inputs
    sub1 = set(inputs) - set(outputs)
    for name in sub1:
        if name not in onnx_initial:
            print("{} not in onnx initial tensor".format(name))
            assert name in graph_inputs, name

    # produced but never consumed: normally equal to the graph outputs
    sub2 = set(outputs) - set(inputs)
    print("sub2 == graph outputs:", sub2 == graph_outputs)

    return sub1, sub2


if __name__ == '__main__':
    onnx_path = 'mqbench_qmodel_for_tengine.onnx'
    ret = onnx_parser(onnx_path)
