基本思想:学习 TensorRT 教程(来自 bilibili),参考附录一。
一、代码:
import torch
import torch.nn as nn


class Model(nn.Module):
    """Toy model: one 3x3 conv with constant parameters, then a per-sample flatten.

    This is the baseline variant: the flatten uses ``x.size(0)`` directly, so
    when traced for ONNX export the batch size is typically recorded as a
    runtime shape computation (Shape -> Gather -> Unsqueeze -> Concat feeding
    Reshape) rather than as a constant.
    """

    def __init__(self):
        super().__init__()
        # 1 -> 1 channel, 3x3 kernel; stride 1 with padding 1 keeps H and W.
        self.conv = nn.Conv2d(1, 1, 3, stride=1, padding=1, bias=True)
        # Fixed parameter values so the outputs and exported graph are reproducible.
        nn.init.constant_(self.conv.weight, 0.3)
        nn.init.constant_(self.conv.bias, 0.2)

    def forward(self, x):
        """Apply the conv, then reshape the result to ``(batch, -1)``.

        Prints the tensor shape before and after the reshape.
        """
        x = self.conv(x)
        print(x.shape)
        x = x.view(x.size(0), -1)
        print(x.shape)
        return x


def main():
    """Run the model once eagerly, then export it to ``example.onnx``."""
    model = Model().eval()
    x = torch.full((1, 1, 3, 3), 1.0)
    model(x)  # eager run: prints the intermediate shapes
    torch.onnx.export(model, (x,), "example.onnx", verbose=True)


if __name__ == "__main__":
    main()
模型结果
修改一
import torch
import torch.nn as nn


class Model(nn.Module):
    """Toy model: one 3x3 conv with constant parameters, then a per-sample flatten.

    Modification 1: the batch size is wrapped in ``int(...)`` so that, during
    ONNX tracing, it is recorded as a constant instead of a traced shape
    computation. PyTorch may emit a TracerWarning for this conversion.
    """

    def __init__(self):
        super().__init__()
        # 1 -> 1 channel, 3x3 kernel; stride 1 with padding 1 keeps H and W.
        self.conv = nn.Conv2d(1, 1, 3, stride=1, padding=1, bias=True)
        # Fixed parameter values so the outputs and exported graph are reproducible.
        nn.init.constant_(self.conv.weight, 0.3)
        nn.init.constant_(self.conv.bias, 0.2)

    def forward(self, x):
        """Apply the conv, then reshape the result to ``(batch, -1)``.

        Prints the conv output shape, the batch size, and the final shape.
        """
        x = self.conv(x)
        print(x.shape)
        print(x.size(0))
        # int(): freeze the batch size as a plain Python int for the tracer.
        x = x.view(int(x.size(0)), -1)
        print(x.shape)
        return x


def main():
    """Run the model once eagerly, then export it to ``example.onnx``."""
    model = Model().eval()
    x = torch.full((1, 1, 3, 3), 1.0)
    model(x)  # eager run: prints the intermediate shapes
    torch.onnx.export(model, (x,), "example.onnx", verbose=True)


if __name__ == "__main__":
    main()
结果
修改二:代码
import torch
import torch.nn as nn


class Model(nn.Module):
    """Toy model: one 3x3 conv with constant parameters, then a per-sample flatten.

    Modification 2: the batch dimension is inferred (``-1``) and the
    per-sample feature count is computed with ``int(...)``. The tracer
    therefore records a fixed reshape target rather than a shape computation.
    PyTorch may emit a TracerWarning for the conversion.
    """

    def __init__(self):
        super().__init__()
        # 1 -> 1 channel, 3x3 kernel; stride 1 with padding 1 keeps H and W.
        self.conv = nn.Conv2d(1, 1, 3, stride=1, padding=1, bias=True)
        # Fixed parameter values so the outputs and exported graph are reproducible.
        nn.init.constant_(self.conv.weight, 0.3)
        nn.init.constant_(self.conv.bias, 0.2)

    def forward(self, x):
        """Apply the conv, then reshape the result to ``(-1, numel // batch)``.

        Prints the conv output shape, its element count, and the final shape.
        An empty batch (``x.size(0) == 0``) raises ``ZeroDivisionError``.
        """
        x = self.conv(x)
        print(x.shape)
        print(x.numel())
        # Batch inferred via -1; features per sample frozen as a Python int.
        x = x.view(-1, int(x.numel() // x.size(0)))
        print(x.shape)
        return x


def main():
    """Run the model once eagerly, then export it to ``example.onnx``."""
    model = Model().eval()
    x = torch.full((1, 1, 3, 3), 1.0)
    model(x)  # eager run: prints the intermediate shapes
    torch.onnx.export(model, (x,), "example.onnx", verbose=True)


if __name__ == "__main__":
    main()
结果
也可以直接使用大老师的 onnxsim(onnx-simplifier)对导出的 ONNX 模型进行简化:
python3 -m onnxsim example.onnx example_sim.onnx
Simplifying...
Finish! Here is the difference:
┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓
┃ ┃ Original Model ┃ Simplified Model ┃
┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩
│ Concat │ 1 │ 0 │
│ Constant │ 1 │ 0 │
│ Conv │ 1 │ 1 │
│ Gather │ 1 │ 0 │
│ Reshape │ 1 │ 1 │
│ Shape │ 1 │ 0 │
│ Unsqueeze │ 1 │ 0 │
│ Model Size │ 588.0B │ 419.0B │
└────────────┴────────────────┴──────────────────┘
结果
二、yolov5-5.0 的模型转换(代码取自 GitHub 上 ultralytics/yolov5 的 Releases 页面,v5.0 版本)
ubuntu@ubuntu:~/yolov5-5.0$ python models/export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
修改源码(models/yolo.py 中 Detect 类的 forward):
# Original line: bs, _, ny, nx = x[i].shape
# map(int, ...) turns the traced sizes into plain Python ints, so the ONNX
# exporter records them as constants instead of emitting Shape/Gather ops.
# NOTE(review): ny/nx are then fixed to the export-time input resolution.
bs, _, ny, nx =map(int, x[i].shape)
三、修改第三条
源码
# Original line: z.append(y.view(bs, -1, self.no))
# The batch dim becomes -1 (inferred), and the middle dim is computed as a
# Python int. The reshape target is therefore constant apart from the batch.
# Assumes y has layout (batch, na, ny, nx, no), so y.size(1)*y.size(2)*y.size(3)
# is na*ny*nx. TODO confirm against how y is derived from x[i].
z.append(y.view(-1,int(y.size(1)*y.size(2)*y.size(3)) , self.no))
同样需要将 reshape(view)的第一个维度改成 -1,继续修改:
# Original line: x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
# The first (batch) dim is passed as -1 instead of bs, so it is inferred rather
# than baked in. Layout: view to (batch, na, no, ny, nx), then permute to
# (batch, na, ny, nx, no).
x[i] = x[i].view(-1, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
四、插播:用 ncnn 转换模型。经过上述修改去掉 Gather 和 Shape 节点之后,只剩下 Slice 报错了:
ubuntu@ubuntu:~/ncnn/build/install/bin$ ./onnx2ncnn /home/ubuntu/sxj_demo/yolov5-5.0/weights/yolov5s.onnx /home/ubuntu/sxj_demo/yolov5-5.0/weights/yolov5s.param /home/ubuntu/sxj_demo/yolov5-5.0/weights/yolov5s.bin
Unsupported slice step !
Unsupported slice step !
Unsupported slice step !
Unsupported slice step !
Unsupported slice step !
Unsupported slice step !
Unsupported slice step !
Unsupported slice step !