Error when exporting a torch custom operator to an ONNX model

The error message:

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1195, in _export
    _C._check_onnx_proto(proto, full_check=True)
RuntimeError: No Op registered for MYSELU with domain_version of 11

==> Context: Bad node spec for node. Name: MYSELU_2 OpType: MYSELU

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "get_onnx.py", line 69, in 
    torch.onnx.export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 350, in export
    return utils.export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 163, in export
    _export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1197, in _export
    raise torch.onnx.CheckerError(e)
torch.onnx.CheckerError: No Op registered for MYSELU with domain_version of 11

==> Context: Bad node spec for node. Name: MYSELU_2 OpType: MYSELU

Code that triggers the error (the ONNX checker validates every node against the registered op schemas, and the custom MYSELU op has no schema in the default domain at opset 11, hence the failure):

import torch
import torch.nn as nn
import torch.onnx
import torch.autograd
import os

# Define the custom op as an autograd Function
class MYSELUImpl(torch.autograd.Function):
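    # symbolic() tells torch.onnx.export how to translate this Function into
    # ONNX graph nodes: a custom MYSELU op fed by x, p, and a Constant tensor.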
    @staticmethod
    def symbolic(g, x, p) -> torch._C.Value:
        return g.op("MYSELU", x, p, 
            g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.float32)),
            attr1_s="属性"
        )
    
    @staticmethod
    def forward(ctx, x, p):
        return x * 1 / (1 + torch.exp(-x))

class MySelu(nn.Module):
    def __init__(self, n) -> None:
        super().__init__()
        self.param = nn.parameter.Parameter(torch.arange(n).float())

    def forward(self, x):
        # Per the official docs, invoke the Function via apply(), not by calling forward() directly
        return MYSELUImpl.apply(x, self.param)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, 3, padding=1)
        self.conv.weight.data.fill_(1)
        self.conv.bias.data.fill_(0)

        self.myselu = MySelu(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.myselu(x)
        return x


# This module contains the opset 11 export (symbolic) code; to tweak export details, modify the code there
# import torch.onnx.symbolic_opset11
print("The opset symbolic files live here:", os.path.dirname(torch.onnx.__file__))

model = Model().eval()
input = torch.tensor([
    # batch 0
    [
        [1,   1,   1],
        [1,   1,   1],
        [1,   1,   1],
    ],
    # batch 1
    [
        [-1,   1,   1],
        [1,   0,   1],
        [1,   1,   -1]
    ]
], dtype=torch.float32).view(2, 1, 3, 3)

output = model(input)
print(f"inference output = \n{output}")

dummy = torch.zeros(1, 1, 3, 3)
torch.onnx.export(
    model,
    dummy,
    "myselu.onnx",
    input_names=["image"],
    output_names=["output"],
    opset_version=11,
    verbose=True,
    dynamic_axes={
        "image": {0:"batch"},
        "output": {0:"batch"}
    },
    # operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
)
print("Done.!")

Solution (link: ONNX custom operator runtime error - #2 by sksenthilkumar - PyTorch Forums):

Add operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK to the torch.onnx.export call, as shown in the sketch below.
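Concretely, this is the argument already present (commented out) at the bottom of the export call above; with a non-default operator export type, the exporter skips the strict ONNX checker that rejects the unregistered MYSELU node. Re-running the same export with the flag enabled, reusing the model and dummy defined earlier:

torch.onnx.export(
    model,
    dummy,
    "myselu.onnx",
    input_names=["image"],
    output_names=["output"],
    opset_version=11,
    verbose=True,
    dynamic_axes={
        "image": {0: "batch"},
        "output": {0: "batch"}
    },
    # Non-default export type: the strict ONNX checker is not run,
    # so the custom MYSELU node no longer aborts the export
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
)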

Alternatively, downgrade PyTorch to 1.1 or 1.2.

Or pass enable_onnx_checker=False, though this argument is now deprecated and ignored.
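For reference, on older PyTorch releases the checker could also be disabled directly in the call; a minimal sketch assuming such a version (newer releases warn and ignore, or reject, the argument):

torch.onnx.export(
    model,
    dummy,
    "myselu.onnx",
    input_names=["image"],
    output_names=["output"],
    opset_version=11,
    enable_onnx_checker=False  # only honored by older PyTorch versions
)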
