For anyone who frequently converts models, onnx-simplifier is an extremely handy tool: it optimizes ONNX models exported from PyTorch, TensorFlow, or Paddle, removing a lot of glue nodes and applying graph optimizations, yielding a clean and readable model graph.
The codebase is quite small. The old version was pure Python in a single file, with essentially two functions doing all the work; the new version has been rewritten in C++. Below I walk through the optimization principles using the new C++ source.
onnx::ModelProto Simplify(
    const onnx::ModelProto& model,
    std::optional<std::vector<std::string>> skip_optimizers,
    bool constant_folding, bool shape_inference, size_t tensor_size_threshold) {
  // `config` is shared state defined elsewhere in onnxsim that holds the
  // simplifier settings
  config.tensor_size_threshold = tensor_size_threshold;
  config.optimizer_passes.clear();
  // skip_optimizers == nullopt means skipping all optimizers, so
  // config.optimizer_passes is left empty
  if (skip_optimizers) {
    std::vector<std::string> passes;
    const auto all_passes = onnx::optimization::GetFuseAndEliminationPass();
    for (const auto& pass : all_passes) {
      if (std::find(skip_optimizers->begin(), skip_optimizers->end(), pass) ==
          skip_optimizers->end()) {
        passes.push_back(pass);
      }
    }
    config.optimizer_passes = passes;
  }
  // constant folding is where the glue nodes get fused away
  auto FoldConstant = constant_folding ? _FoldConstant : Identity;
  auto InferShapes = shape_inference ? _InferShapes : Identity;
  Check(model);
  // alternate shape inference, optimization, and folding until the model
  // stops changing, with at most 15 rounds per level
  auto OptAndShape =
      FixedPointFn(std::function{InferShapes}, std::function{Optimize}, 15);
  auto OptAndShapeAndFold =
      FixedPointFn(std::function{OptAndShape}, std::function{FoldConstant}, 15);
  auto sim_model = OptAndShapeAndFold(model);
  Check(sim_model);
  return sim_model;
}
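The composition here is done by FixedPointFn: it chains two transforms and reapplies them until the model reaches a fixed point (another round changes nothing) or the iteration cap is hit, since optimization, shape inference, and folding each expose new opportunities for the others. The sketch below is a minimal reconstruction of the idea, not the verbatim onnxsim implementation; in particular, specializing to ModelProto and comparing serialized bytes as the termination test are my assumptions.
#include <functional>  // plus the onnx proto headers used throughout

// Minimal sketch of a fixed-point combinator over ModelProto (the real
// onnxsim version is generic; names and details are assumed).
std::function<onnx::ModelProto(const onnx::ModelProto&)> FixedPointFn(
    std::function<onnx::ModelProto(const onnx::ModelProto&)> f1,
    std::function<onnx::ModelProto(const onnx::ModelProto&)> f2,
    size_t max_iters) {
  return [f1, f2, max_iters](const onnx::ModelProto& init) {
    onnx::ModelProto cur = init;
    for (size_t i = 0; i < max_iters; ++i) {
      onnx::ModelProto next = f2(f1(cur));
      // fixed point reached: one more round changed nothing
      if (next.SerializeAsString() == cur.SerializeAsString()) {
        return next;
      }
      cur = std::move(next);
    }
    return cur;
  };
}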
This yields a new model in which every node that can be evaluated as a constant has been converted into an initializer. The folding itself lives in _FoldConstant:
onnx::ModelProto _FoldConstant(const onnx::ModelProto& model) {
  const auto& tmp = model;
  {
    // shadow the const parameter with a mutable copy
    onnx::ModelProto model;
    model.CopyFrom(tmp);
    // split the nodes into constant and non-constant ones
    const auto [const_nodes, non_const_nodes] = GetConstantNodes(model);
    // evaluate each constant node and store its outputs as initializers
    for (const auto& x : const_nodes) {
      RunOpAndAddInitializer(model, x);
    }
    // rebuild the graph, keeping only the non-constant nodes
    model.mutable_graph()->clear_node();
    for (const auto& x : non_const_nodes) {
      *model.mutable_graph()->add_node() = x;
    }
    return model;
  }
}
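To make the fold concrete, consider a hypothetical one-node graph (this snippet is an illustration I wrote, not from the onnxsim sources). Both inputs of the Add are initializers, so GetConstantNodes classifies the node as constant; RunOpAndAddInitializer computes the sum and appends it as a new initializer, and the rebuilt node list ends up empty.
// Illustration only: a graph that _FoldConstant collapses entirely.
onnx::ModelProto m;
auto* g = m.mutable_graph();
auto* a = g->add_initializer();
a->set_name("a");  // dims and raw data omitted for brevity
auto* b = g->add_initializer();
b->set_name("b");
auto* add = g->add_node();
add->set_op_type("Add");
add->add_input("a");
add->add_input("b");
add->add_output("sum");
// After _FoldConstant(m), "sum" shows up as a third initializer and
// m.graph().node() is empty.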
std::pair<std::vector<onnx::NodeProto>, std::vector<onnx::NodeProto>>
GetConstantNodes(const onnx::ModelProto& model) {
  std::vector<std::string> const_names;
  std::vector<onnx::NodeProto> const_nodes;
  std::vector<onnx::NodeProto> non_const_nodes;
  // seed the constant set with all initializer names
  std::transform(
      model.graph().initializer().begin(), model.graph().initializer().end(),
      std::back_inserter(const_names), [](const auto& x) { return x.name(); });
  // nodes are already topologically sorted
  for (const auto& node : model.graph().node()) {
    // a node is constant iff it is a deterministic official op, not a QDQ
    // op, has no subgraph attribute, does not produce an oversized tensor,
    // and all of its inputs are already known constants
    // clang-format off
    if (IsOfficialOp(node.domain(), node.op_type()) &&
        IsDeterministic(node.domain(), node.op_type()) &&
        !IsQDQ(node.domain(), node.op_type()) &&
        !HasSubgraph(node) &&
        !ProduceLargeTensor(model, node, config.tensor_size_threshold) &&
        // clang-format on
        std::all_of(node.input().begin(), node.input().end(),
                    [&const_names](const auto& x) {
                      return std::find(const_names.begin(), const_names.end(),
                                       x) != const_names.end();
                    })) {
      // its outputs become constants for the nodes downstream
      const_names.insert(const_names.end(), node.output().begin(),
                         node.output().end());
      const_nodes.push_back(node);
    } else {
      non_const_nodes.push_back(node);
    }
  }
  return {const_nodes, non_const_nodes};
}
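The predicates in that condition are simple lookups: IsOfficialOp and IsDeterministic consult the op schema, IsQDQ excludes QuantizeLinear/DequantizeLinear nodes (presumably because folding them would drop quantization information), and ProduceLargeTensor enforces tensor_size_threshold so folding doesn't bloat the model with huge constant tensors. HasSubgraph rejects control-flow ops such as If, Loop, and Scan, whose attributes carry nested graphs; a plausible sketch, assuming the standard ONNX protobuf accessors:
// Sketch of a subgraph check (the real onnxsim helper may differ):
// control-flow ops store their bodies in GRAPH / GRAPHS attributes.
bool HasSubgraph(const onnx::NodeProto& node) {
  for (const auto& attr : node.attribute()) {
    if (attr.has_g() || attr.graphs_size() > 0) {
      return true;
    }
  }
  return false;
}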
std::vector<onnx::TensorProto> RunOp(onnx::ModelProto& model,
                                     const onnx::NodeProto& op) {
  // collect the node's inputs (deduplicated) together with their
  // initializer tensors
  std::vector<std::string> input_names;
  std::vector<onnx::TensorProto> input_tps;
  for (const auto& input : op.input()) {
    if (std::find(input_names.begin(), input_names.end(), input) !=
        input_names.end()) {
      continue;
    }
    input_names.push_back(input);
    auto in_tp = FindInitializerByName(model, input);
    input_tps.push_back(in_tp);
  }
  // build a throwaway model containing just this single node
  onnx::ModelProto op_model;
  op_model.set_ir_version(model.ir_version());
  for (const auto& x : model.opset_import()) {
    *op_model.add_opset_import() = x;
  }
  *op_model.mutable_graph()->add_node() = op;
  for (const auto& x : input_names) {
    *op_model.mutable_graph()->add_input() = FindValueInfoProtoByName(model, x);
  }
  for (const auto& x : op.output()) {
    onnx::ValueInfoProto vi;
    // In principle an output ValueInfoProto must have a type, but this is
    // not checked.
    vi.set_name(x);
    *op_model.mutable_graph()->add_output() = vi;
  }
  // execute the single-op model to get concrete output tensors
  auto output_tps = ModelExecutor::Run(op_model, input_tps);
  for (size_t i = 0; i < op.output_size(); i++) {
    output_tps[i].set_name(op.output(i));
  }
  return output_tps;
}
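RunOp relies on two linear-scan helpers, FindInitializerByName and FindValueInfoProtoByName, and on ModelExecutor::Run, which feeds the single-op model to a real inference backend (onnxruntime, if I recall the onnxsim build correctly) to produce the output tensors. A sketch of the first helper, with an assumed rather than verbatim shape:
// Sketch: linear search over the graph initializers by name; the real
// helper presumably fails loudly when the name is missing.
onnx::TensorProto FindInitializerByName(const onnx::ModelProto& model,
                                        const std::string& name) {
  for (const auto& init : model.graph().initializer()) {
    if (init.name() == name) {
      return init;
    }
  }
  return onnx::TensorProto{};  // not found; return an empty tensor
}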
void RunOpAndAddInitializer(onnx::ModelProto& model,
                            const onnx::NodeProto& op) {
  // evaluate the node, then register every output as a graph initializer
  const auto output_tps = RunOp(model, op);
  for (const auto& output_tp : output_tps) {
    *model.mutable_graph()->add_initializer() = output_tp;
  }
}
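Putting it all together, a caller drives everything through Simplify. One subtlety worth repeating: std::nullopt for skip_optimizers disables all optimizer passes, while an empty vector skips nothing and thus enables all of them. A hypothetical driver (the file handling, error checks, and threshold value are all placeholders):
// Hypothetical usage sketch, not part of the onnxsim sources.
onnx::ModelProto model;
// ... load the model, e.g. via model.ParseFromIstream(...) ...
auto sim = Simplify(model,
                    std::vector<std::string>{},  // skip none: run every pass
                    /*constant_folding=*/true,
                    /*shape_inference=*/true,
                    /*tensor_size_threshold=*/1024 * 1024);
// `sim` is the simplified model, ready to be serialized back to disk.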