作者写这篇文章纯粹是太过心烦,找点东西记录一下。刚好手上在做这件事,于是记录一下。
本文包含两部分:
首先我们要明确两件事:
那么,接下来我会举一个简单的例子来说明:
import torch
import torch.onnx as onnx
# 可有可无,不是重点。定义一个简单的PyTorch模型,可以换成你自己的模型
class SimpleModel(torch.nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
def forward(self, x):
x = x.unsqueeze(0)
return x
# 可有可无,不是重点。创建一个示例模型实例,如果你有pth文件可以在这里加载
model = SimpleModel()
# 定义输入张量,这个要关注一下,张量的形状必须符合你模型的要输入的模型的张量的形状,这个input会在模型里完整的跑一遍
input_tensor = torch.randn(192, 36)
# 导出模型为ONNX格式
onnx_file_path = "result.onnx"
onnx.export(model, input_tensor, onnx_file_path)
总之,这样子之后我们就可以获得一个onnx模型了。这一小节结束!!!! 开心!!!!!
让人痛苦的事情开始了,我们为什么要转onnx?那当然是为了部署咯。那要是硬件不支持怎么办?
问得好!我带你们打!怎么会有这种令人头大的事情!!!! 这种时候就要开始准备改模型了,但是网上的算子对照表比较少,可能大部分模型都比较正常吧,我这个在网络里有unsqueeze,切片等等操作,这里记录一下咯。
再次声明:本人用的torch版本为1.8,目标onnx的opset_version为11,诸位如果有遇到下面没有记录到的算子,还请用上面的这段代码自己去尝试。onnx可视化可以用Netron来做:
pytorch的unsqueeze就对应着onnx的unsqueeze:
pytorch中的cat对应着onnx中的concat:
在这里我要多哔哔两句了,还记得我们前面说的onnx是先记录再转吗?clone在这里没有显式的结构可能就是因为它被整合到了concat里了,有错误的话欢迎指出。
torch中的expand对应的onnx的结构就是这样子了,比较令人意外的是,我原本以为参数不同结果会不一样,没想到都是长这个样子,那就对不住了,我只好说:实践是检验真理的唯一标准!!!!!
torch中repeat对应的onnx算子差不多就是下面这样子,话说结构居然比expand还要简单一些。
这一个的结构就比较多变了,进行不同的操作,组合在一起的形状总是不太一样,不过不管怎么变,总是会出现ScatterND这个算子,这次我会给出完整代码,感兴趣的话可以自己拿回去修改玩一玩:
import torch
import torch.onnx as onnx
# 定义一个简单的PyTorch模型
class SimpleModel(torch.nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
def forward(self, x):
# 主要改这里
x[0] = x[0] + 1
return x
# 创建一个示例模型实例
model = SimpleModel()
# 定义输入张量
input_tensor = torch.randn(192, 36)
# 导出模型为ONNX格式
onnx_file_path = "kkk.onnx"
onnx.export(model, input_tensor, onnx_file_path, opset_version=11)
pytorch中对张量的索引会对应着onnx的gather算子:
就先到这里吧,写的很乱,但是总归是记录了一点什么。希望能对你有帮助,润!