PyTorch 转 ONNX:自定义算子的导出

部署模型时,有时会用到 ONNX 标准算子集中没有的自定义算子;直接导出通常会报错,无法得到 ONNX 模型。解决办法是把该算子封装成 torch.autograd.Function,并实现 symbolic 静态方法,让导出器把它写成一个自定义的插件节点(本例中为 ScatterMaxPlugin),再由推理端的插件来实现。代码如下:

import torch
import torch.nn as nn

class ScatterMax(torch.autograd.Function):
    """Stand-in autograd function that exports as a custom ONNX node.

    The eager ``forward`` does not compute a scatter-max; it only returns a
    zero tensor of a data-dependent shape so the surrounding model can be
    run. During ONNX export, ``symbolic`` is what defines the emitted node.
    """

    @staticmethod
    def forward(ctx, src):
        """Return a float32 zero tensor shaped from ``src``.

        Args:
            ctx: Autograd context (unused).
            src: Input tensor; must have at least two dimensions because
                ``src.shape[1]`` is read.

        Returns:
            Zeros of shape ``(U, src.shape[1])`` on ``src``'s device, where
            ``U`` is the number of distinct values in ``src``.
        """
        distinct_count = torch.unique(src).shape[0]
        result_shape = (distinct_count, src.shape[1])
        return torch.zeros(result_shape, dtype=torch.float32, device=src.device)

    @staticmethod
    def symbolic(g, src):
        """Emit a single ``ScatterMaxPlugin`` node taking ``src`` as input."""
        plugin_node = g.op("ScatterMaxPlugin", src)
        return plugin_node

class VFE(nn.Module):
    """Small feature encoder: Linear -> BatchNorm1d -> ``ScatterMax``.

    NOTE(review): ``BatchNorm1d`` treats dimension 1 of a 3-D input as the
    channel axis, so ``num_features=32`` matches the second dimension of the
    input rather than the 64 linear output features -- confirm this is the
    intended normalisation axis.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as they are applied.
        linear = nn.Linear(in_features=10, out_features=64, bias=False)
        norm = nn.BatchNorm1d(
            num_features=32,
            eps=0.001,
            momentum=0.01,
            affine=True,
            track_running_stats=True,
        )
        self.pfn_layer0 = nn.Sequential(linear, norm)

    def forward(self, x):
        """Apply ``pfn_layer0`` to ``x`` and pass the result to ``ScatterMax``."""
        features = self.pfn_layer0(x)
        return ScatterMax.apply(features)

if __name__ == '__main__':
    # Demo entry point: build the model, run one sanity forward pass, then
    # export it to "vfe.onnx" with the custom op kept as a plugin node.
    pillarvfe = VFE()
    # Put BatchNorm in inference mode so the sanity pass below does not
    # update its running statistics before the weights are exported.
    pillarvfe.eval()

    # Named `dummy_input` rather than `input` to avoid shadowing the builtin.
    # presumably (pillars, points per pillar, point features) -- confirm.
    dummy_input = torch.zeros((40000, 32, 10))

    # Sanity check only: make sure the model runs end to end before export.
    with torch.no_grad():
        pillarvfe(dummy_input)

    torch.onnx.export(pillarvfe,
                      dummy_input,
                      "vfe.onnx",
                      export_params=True,
                      opset_version=11,
                      do_constant_folding=True,
                      keep_initializers_as_inputs=True,
                      input_names=["input"],
                      output_names=["output"],
                      operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

参考链接:pytorch自定义算子插件

你可能感兴趣的:(pytorch,pytorch,深度学习,python)