【pytorch】——自定义一个算子并导出到onnx

pytorch, onnx

摘要:为了将自定义算子的参数,或者自己想要保存的参数序列化到onnx中。

code

import torch
import torch.nn as nn
from torch.autograd import Function
import onnx
import torch.onnx

class Requant_(Function):
    @staticmethod
    def forward(ctx, input, requant_scale, shift):               # ctx 必须要
        input = input.double() * requant_scale / 2**shift        # 为了等价于c中的移位操作。会存在int32溢出
        input = torch.floor(input).float()

        return torch.floor(input)
    
    @staticmethod
    def symbolic(g, *inputs):
        return g.op("Requant", inputs[0], scale_f=23.0, shift_i=8)

requant_ = Requant_.apply

class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = x.view(-1)
        x = requant_(x, 5, 5)
        return x

net = TinyNet().cuda()
ipt = torch.ones(2,3,12,12).cuda()
torch.onnx.export(net, (ipt,), 'tinynet.onnx', opset_version=11, enable_onnx_checker=False)
print(onnx.load('tinynet.onnx'))

关键点:

  • 继承自torch.autograd
  • scale_f=23.0, shift_i=8,_f表示浮点数,_i表示整形int32类型

onnx 模型
【pytorch】——自定义一个算子并导出到onnx_第1张图片

总结

这种是在pytorch中新写一个op,并序列化到onnx中,另外一个想法是:如果修改已有op的onnx序列化,比如conv2d,upsample等。得到onnx模型中,还需要对onnx模型解析,在把onnx模型转换成自己想要的表达。

你可能感兴趣的:(pytorch,pytorch,深度学习,机器学习)