pytorch自带网络_Pytorch转ONNX-实战篇2(实战踩坑总结)

前两篇文章分别从理论和ONNX的核心机制描述了Pytorch转ONNX需要注意的事情。接下来这篇文章没有什么核心主旨,只是纯粹记录我当时做项目的时候踩的坑以及应对方案

(1)Pytorch2ONNX不支持对slice对象赋值

下面这段代码是不被Pytorch原生的onnx转换接口支持的,即不能对slice对象赋值

preds

仔细想想其实也比较合理,因为上面的操作也很难在DAG上被表示,因为并不仅仅是把preds中的那个区域取出来弄个新的变量,然后在上面+1,而是直接把preds的一部分改掉了。当时我负责MMSeg的slide inference转换的时候遇到了这个问题,解决方案如下:

preds += F.pad(crop_seg_logit,
               (int(x1), int(preds.shape[3] - x2), int(y1),
                int(preds.shape[2] - y2)))

即我对crop_seg_logit做了一个padding,把它变成了和preds一样的大小,这样我就直接变成了矩阵相加,没必要变成slice的操作了

这个方法自然很丑,而且会引出一个新的问题,那就是Pytorch生成的onnx padding的格式,onnx runtime接收的格式以及TensorRT需要的格式都不一样。这个就是之后的问题了(超纲了,不讲了)

这里具体的例子我懒得查了,以二维矩阵的填充为例。只记得一个转出来的是(begin0, begin1, end0, end1),另一个是(begin0, end0, begin1, end1)

这里面begin0代表第0维左边的填充数量,end0代表右边的填充数量

(2)resize

当时做segmentation模型的时候,最重要的就是resize操作。ONNX里面的resize要求output shape必须为常量(即tuple of int),因此不可以用tensor.Size作为输入,因为人家并不是tuple of int

if isinstance(size, torch.Size):
    size = tuple(int(x) for x in size)

所以我们必须手动粗暴的把torch.Size变成tuple of int

当时有reviewer吐槽我这个方法丑,要我改成tuple(size),说Pytorch重载了tuple,直接可以把torch.Size变成tuple of int。但是很诡异的是在正常情况下的确可以,但如果一旦进入了ONNX tracining模式,这个方法就失效了。我简单看了看,推测是因为对tuple的重载是在C++层面做的,而ONNX tracing也会涉及到一些C++层面的事情,也就是说ONNX tracing会重载一些C++的部分,可能正好就把tuple给抹掉了

(3) 应对kwargs的约束

pytorch自带的onnx转换api: torch.onnx.export,只支持args参数。一般来说调用这个api只需要提供model(喜闻乐见的nn.Module),调用model的参数args(也就是调用model.forwrd()的参数)以及导出的文件名f。然后这个函数就会内部执行一遍: model(*args),执行的时候做tracining

pytorch自带网络_Pytorch转ONNX-实战篇2(实战踩坑总结)_第1张图片

但是我们知道一般来说除了args,还需要kwargs,比如model(input, getloss=False),其中input就是args,False就是kwargs。OpenMMLab里面几乎所有的model都需要kwargs

为了绕开这个约束,我们需要利用python的partial函数,将model做个封装:

model

这样我们可以给model提供需要的kwargs,同时又可以原封不动的调用torch.onnx.export

注意,kwargs不能包括网络的输入,比如如果你想把input image放进args,那么得到的onnx就会是一个没有输入的图(它会把kwargs里面的input image当成一个常量)

(4)Pytorch和ONNX Runtime结果对齐

OpenMMLab系列提供了一个很有用的功能,就是自动比对Pytorch和ONNXRuntime的精度。这个功能可以帮助用户确定转出来的ONNX有没有问题。

然而之前也提到过,ONNXRuntime和Pytorch需要的ONNX格式不一样,而且有些计算也不一样,因此就算结果对不上,也不能代表什么

在某些操作上,ONNXRuntime和Pytorch的行为不一致。比如对一个一维tensor:[0,0,0]调用argmax,那么ONNXRuntime返回的是0,而Pytorch是1(举个例子,具体的差异我记不清了)

当时我在做Detection模型的自动比对的时候就遇到了问题,在经历了nms操作之后,bbox会根据score的大小做排序,但score相同的情况下,ONNXRuntime和Pytorch的结果就会有差异。因此我们最后只选择比对score,而不管bbox的dx,dy这些信息了

商汤科技的算法中台组常年招收正式/实习的员工,组里面有很多大佬,负责维护OpenMMlab框架,组里的老大就是网红 @陈恺,可以近距离和他接触(目前组里主要都在深圳,但是也欢迎base在北京,上海,香港的小伙伴)。目前有比较多的岗位,不仅仅负责训练,部署框架的开发,还会接触最新的算法(检测分割抠图视频追踪啥的),发论文什么的

有兴趣的小伙伴欢迎来联系我,可以私信,或者直接发简历到hanruobing@http://pku.edu.cn

你可能感兴趣的:(pytorch自带网络)