踩坑记录 | pytorch转ONNX时遇到的问题 tinynet smpl

1. torch环境问题

我在转换tinynet模型为onnx时,会产生如下报错

ImportError:cannot import name 'container_abcs'

这是由于torch版本不对,改为1.7.1就没有这个问题;因为1.8版本之后container_abcs就已经被移除了。但对于其他模型对这个版本又不适用,onnx对torch的版本是较为敏感的。

转换成功之后还需要验证onnx模型结果,随机生成输入tensor,观察输出结果是否一致,同样注意torch版本问题。

2. 去除不支持算子

在观察tinynex.onnx时发现了where算子,最终将问题定位在tinynet的forward函数中的三个if语句,调整后where算子消失。
踩坑记录 | pytorch转ONNX时遇到的问题 tinynet smpl_第1张图片

在转换smpl模型时发现了einsum算子,定位问题出现来源于torch.einsum函数,于是使用torch.mm重新写一下就解决了。

踩坑记录 | pytorch转ONNX时遇到的问题 tinynet smpl_第2张图片

解决后发现smpl模型同样存在where算子,注释掉if语句并没有效果,最后排查出问题存在于torch源码中存在的切片赋值操作,onnx不支持切片对象的赋值操作

rel_joints[:, 1:] -= joints[:, parents[1:]]

将这句话的代码逻辑使用了F.pad和torch.cat重新写,四个where均消失,但按上述修改后,发现onnx中最下面又多了一个if算子,因此还是将函数中的逻辑重写,去除F.pad和torch.cat操作,最后if节点消失。

综上所述,torch.einsum、切片的赋值操作onnx均不支持,关于if语句,应该有一定的影响。

你可能感兴趣的:(pytorch,深度学习,机器学习)