pytorch导出onnx时遇到不支持的算子怎么解决

在使用pytorch模型训练完成之后,我们现在使用的比较多的一种方法是将pytorch模型转成onnx格式的模型中间文件,然后再根据使用的硬件来生成具体硬件使用的深度学习模型,比如TensorRT。
在从pytorch模型转为onnx时,我们可能会遇到部分算子无法转换的问题,本篇注意记录下解决方法。

在导出onnx时,如果出现报错的算子,可以先在下面的链接中查找onnx算子是否支持
https://github.com/onnx/onnx/blob/main/docs/Operators.md

pytorch中有,onnx中也有的算子

导出时使用的onnx op 版本低导致

这个就好解决了,把op库的版本提高就行,但是有可能提高了版本以后,又出现了原来支持的算子现在又不支持了,这个再说

pytorch中没有注册某个onnx算子

如果是这种情况,就按照下面的方式进行:

from torch.onnx import register_custom_op_symbolic
# 创建一个asinh算子的symblic,符号函数,用来登记
# 符号函数内部调用g.op, 为onnx计算图添加Asinh算子
#   g: 就是graph,计算图
#   也就是说,在计算图中添加onnx算子
#   由于我们已经知道Asinh在onnx是有实现的,所以我们只要在g.op调用这个op的名字就好了
#   symblic的参数需要与Pytorch的asinh接口函数的参数对齐
#       def asinh(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...
def asinh_symbolic(g, input, *, out=None):
    return g.op("Asinh", input)

# 在这里,将asinh_symbolic这个符号函数,与PyTorch的asinh算子绑定。也就是所谓的“注册算子”
# asinh是在名为aten的一个c++命名空间下进行实现的

# aten是"a Tensor Library"的缩写,是一个实现张量运算的C++库
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)

另外一个写法
这个是类似于torch/onnx/symbolic_opset*.py中的写法
通过torch._internal中的registration来注册这个算子,让这个算子可以与底层C++实现的aten::asinh绑定
一般如果这么写的话,其实可以把这个算子直接加入到torch/onnx/symbolic_opset*.py中

import functools
from torch.onnx import register_custom_op_symbolic
from torch.onnx._internal import registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)

@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):
    return g.op("Asinh", input)

pytorch中有,onnx中无的算子

继承torch.autograd.Function实现自定义算子

import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolic

OperatorExportTypes = torch._C._onnx.OperatorExportTypes

class CustomOp(torch.autograd.Function):
    @staticmethod 
    def symbolic(g: torch.Graph, x: torch.Value) -> torch.Value:
        return g.op("custom_domain::customOp2", x)

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        x = x.clamp(min=0)
        return x / (1 + torch.exp(-x))

customOp = CustomOp.apply

然后再自己实现custom_domain::customOp2这个算子,如果用TensorRT,就需要自己实现一个插件。

你可能感兴趣的:(模型部署,pytorch,人工智能,python,深度学习,经验分享,笔记)