onnx-simplifier原理探究

简介

onnx-simplifier对于经常转模型的同学来说,是一个非常方便好用的工具,它可以对pytorch、tf或者paddle转换得到的onnx模型做优化,去除很多胶水节点,以及做一些图优化,得到一个简洁明了的模型图

优化原理

整个工程的代码实现很简单,老版本是python实现,只有一个文件,主要功能就两个函数,新版本换成了cpp实现,下面以新版cpp的源码,讲下优化的原理

  • 首先是主函数,可以看到实现融合胶水节点为initializer的函数为FoldConstant
onnx::ModelProto Simplify(
    const onnx::ModelProto& model,
    std::optional<std::vector<std::string>> skip_optimizers,
    bool constant_folding, bool shape_inference, size_t tensor_size_threshold) {
  config.tensor_size_threshold = tensor_size_threshold;
  config.optimizer_passes.clear();
  // skip_optimizers == nullopt means skiping all optimizers, so
  // config.optimizer_passes is empty
  if (skip_optimizers) {
    std::vector<std::string> passes;
    const auto all_passes = onnx::optimization::GetFuseAndEliminationPass();
    for (const auto& pass : all_passes) {
      if (std::find(skip_optimizers->begin(), skip_optimizers->end(), pass) ==
          skip_optimizers->end()) {
        passes.push_back(pass);
      }
    }
    config.optimizer_passes = passes;
  }

  auto FoldConstant = constant_folding ? _FoldConstant : Identity; // 融合胶水节点主要看这个
  auto InferShapes = shape_inference ? _InferShapes : Identity;

  Check(model);
  auto OptAndShape =
      FixedPointFn(std::function{InferShapes}, std::function{Optimize}, 15);
  auto OptAndShapeAndFold =
      FixedPointFn(std::function{OptAndShape}, std::function{FoldConstant}, 15);
  auto sim_model = OptAndShapeAndFold(model);
  Check(sim_model);
  return sim_model;
}
  • 下面看下FoldConstant,从下面代码可以看到整个思路:
    • 基于GetConstantNodes函数得到所有的常量节点和非常量节点
    • 使用onnxruntime进行一次推理,得到常量节点的值,并将常量节点的值及节点名构造成initializer,存入model中
    • 清空model中所有的node
    • 加入非常量节点

这样得到一个新的model,可以表示为常量的节点都转换为initializer

onnx::ModelProto _FoldConstant(const onnx::ModelProto& model) {
  const auto& tmp = model;
  {
    onnx::ModelProto model;
    model.CopyFrom(tmp);
    const auto [const_nodes, non_const_nodes] = GetConstantNodes(model);
    for (const auto& x : const_nodes) {
      RunOpAndAddInitializer(model, x);
    }
    model.mutable_graph()->clear_node();
    for (const auto& x : non_const_nodes) {
      *model.mutable_graph()->add_node() = x;
    }
    return model;
  }
}

  • 看下GetConstantNodes的源码,了解下其怎么得到常量节点的,看第18-22行,可以知道,是基于initializers来判断,如果一个节点的输入都是initializer,那其输出也是一个initializer,然后将输出加入initializers,再判断下一个节点
std::pair<std::vector<onnx::NodeProto>, std::vector<onnx::NodeProto>>
GetConstantNodes(const onnx::ModelProto& model) {
  std::vector<std::string> const_names;
  std::vector<onnx::NodeProto> const_nodes;
  std::vector<onnx::NodeProto> non_const_nodes;
  std::transform(
      model.graph().initializer().begin(), model.graph().initializer().end(),
      std::back_inserter(const_names), [](const auto& x) { return x.name(); });
  // node is already topo sorted
  for (const auto& node : model.graph().node()) {
    // clang-format off
    if (IsOfficialOp(node.domain(), node.op_type()) &&
        IsDeterministic(node.domain(), node.op_type()) &&
        !IsQDQ(node.domain(), node.op_type()) &&
        !HasSubgraph(node) &&
        !ProduceLargeTensor(model, node, config.tensor_size_threshold) &&
        // clang-format on
        std::all_of(node.input().begin(), node.input().end(),
                    [&const_names](const auto& x) {
                      return std::find(const_names.begin(), const_names.end(),
                                       x) != const_names.end();
                    })) {
      const_names.insert(const_names.end(), node.output().begin(),
                         node.output().end());
      const_nodes.push_back(node);
    } else {
      non_const_nodes.push_back(node);
    }
  }
  return {const_nodes, non_const_nodes};
}
  • 看下RunOpAndAddInitializer是怎么实现的,可以看到RunOp,是基于每个输入的op的输入,输出,以及op type,构建一个只有当前op的模型图,然后推理,得到output,这样推理有一个好处,对于我们导出的自定义op,虽然onnxruntime不能推理,但也不影响单个op的推理优化
std::vector<onnx::TensorProto> RunOp(onnx::ModelProto& model,
                                     const onnx::NodeProto& op) {
  std::vector<std::string> input_names;
  std::vector<onnx::TensorProto> input_tps;

  for (const auto& input : op.input()) {
    if (std::find(input_names.begin(), input_names.end(), input) !=
        input_names.end()) {
      continue;
    }
    input_names.push_back(input);
    auto in_tp = FindInitializerByName(model, input);
    input_tps.push_back(in_tp);
  }
  onnx::ModelProto op_model;
  op_model.set_ir_version(model.ir_version());
  for (const auto& x : model.opset_import()) {
    *op_model.add_opset_import() = x;
  }
  *op_model.mutable_graph()->add_node() = op;
  for (const auto& x : input_names) {
    *op_model.mutable_graph()->add_input() = FindValueInfoProtoByName(model, x);
  }
  for (const auto& x : op.output()) {
    onnx::ValueInfoProto vi;
    // In principle output ValueInfoProto must have type. But it is not checked.
    vi.set_name(x);
    *op_model.mutable_graph()->add_output() = vi;
  }

  auto output_tps = ModelExecutor::Run(op_model, input_tps);
  for (size_t i = 0; i < op.output_size(); i++) {
    output_tps[i].set_name(op.output(i));
  }
  return output_tps;
}

void RunOpAndAddInitializer(onnx::ModelProto& model,
                            const onnx::NodeProto& op) {
  const auto output_tps = RunOp(model, op);
  for (const auto& output_tp : output_tps) {
    *model.mutable_graph()->add_initializer() = output_tp;
  }
}

你可能感兴趣的:(onnx,目标检测,深度学习)