pytorch模型转caffe

工具介绍

最近在尝试了一些pytorch模型转caffe模型的方法,推荐一个特别好用的工具,github上的开源项目https://github.com/xxradon/PytorchToCaffe.git

本项目最大的优点是可以支持pytorch各最新版本,对于0.3, 0.3.1, 0.4, 0.4.1 ,1.0, 1.2都有支持

同时项目还提供对caffe中prototxt每层形状和ops的分析工具

工具使用

请参考项目example/中的各例子

经测试该工具对YOLO v3, Deeplab v3+以及Enet等网络都可以做到比较准确的转换

代码分析

项目的主要实现在pytorch_to_caffe.py模块中

模块初始化时会创建Rp类的对象,并用这个对象覆盖pytorch中的层实现,例如卷积层的实现

F.conv2d=Rp(F.conv2d,_conv2d)

在工具使用中会调用pytorch网络的forward()方法,此时在调用到F.conv2d层是就会调用刚才覆盖的Rp(F.conv2d,_conv2d)这个对象中的__call__方法,并在此方法中调用_conv2d

_conv2d是工具内部定义的方法,作用是计算pytorch中的conv,将该层的名字以及计算得到的blob加入到之前创建的Translog中,并创建caffe中的conv实现,将pytorch中的相关权重写入caffe层中

自定义层

由之前的代码分析可知,如果需要用该工具将某个pytorch层转化为caffe层,只需要首先定义一个转换层函数如_conv2d, 用Rp类的对象覆盖原pytorch中的实现,调用pytorch的forward方法后新创建的caffe网络和权重都会保存在Translog对象的cnet属性中,此时再调用该模块的save_prototxt和save_caffemodel方法,就可以得到转换后caffe中的prototxt和caffemodel文件

自定以层和转换pytorch中已有的层的做法是一样的,通过一个例子说明一下添加自定义层的做法。比如在项目中我需要用到原版caffe中的deconv层来实现nearest插值

首先添加转换层函数

def _interpolate_deconv(raw, input,size=None, scale_factor=None, mode='nearest', align_corners=None):
    if mode != "nearest" or align_corners != None:
        raise NotImplementedError("not implement F.interpolate totally")
    x = raw(input,size , scale_factor ,mode)

    layer_name = log.add_layer(name='upsample')
    log.add_blobs([x], name='upsample_blob'.format(type))
    layer = caffe_net.Layer_param(name=layer_name, type='Deconvolution',
                                  bottom=[log.blobs(input)], top=[log.blobs(x)])
    layer.conv_param(x.size()[1],   kernel_size= 2, stride = 2,
                     pad = 0, dilation= 1, bias_term = False, groups = x.size()[1])
    layer.param.convolution_param.bias_term=False
    layer.add_data(np.ones([x.size()[1], 1, 2,2]))
    log.cnet.add_layer(layer)
    return x

再用Rp对象覆盖原来pytorch中的nearest插值的F.interpolate方法

F.interpolate = Rp(F.interpolate,_interpolate_deconv)

这样就实现了对nearest插值的转换

注意事项

使用工具时容易出现如下错误

TypeError: None has type NoneType, but expected one of: bytes, unicode

错误的原因是在运行过程中某一层输入的blob无法与之前已经转换成功的blob匹配,原因可能为

  1. pytorch的模型中包含该工具中没有实现覆盖的层
  2. 在使用工具时没有调用pytorch中的net.eval(),导致模型处于训练模式,出现额外的内部tensor运算也被当成网络层记录下来

你可能感兴趣的:(模型转换)