PyTorch(八)——pyTorch-To-Caffe

目录连接
(1) 数据处理
(2) 搭建和自定义网络
(3) 使用训练好的模型测试自己图片
(4) 视频数据的处理
(5) PyTorch源码修改之增加ConvLSTM层
(6) 梯度反向传递(BackPropogate)的理解
(7) 模型的训练和测试、保存和加载
(8) pyTorch-To-Caffe
(总) PyTorch遇到令人迷人的BUG

PyTorch的学习和使用(八)

Mon 22 Mon 29 Mon 05 settrace grad_fn pyTorch1.0 pyTorchToCaffe 完成进度表

一、pyTorch to Caffe ⇒ \Rightarrow 动态图到静态图的转换

静态图: 网络在输入数据前就预先将网络定义好,与数据无关,将所有的操作过程定义好,运行时填入数据。比如caffe,tensorFlow等框架。
动态图: 数据在网络传输中动态的构建网络。比如pyTorch等。

两者各有优点,静态图由于提前将网络结构确定了,部署十分方便,但是由于数据在网络中的传递过程往往不可知,因此调试较为困难;动态图是根据数据的流动动态的构建网络图,因此数据在网络中的状态都是已知的,调试十分便捷。目前,两者都在吸取对方的优点,tensorFlow也在也如动态图的机制,pyTorch与Caffe2结合,在配合ONNX实现高效的部署。

pyTorch模型转换到caffe模型可以看为动态图到静态图之间的转换,主要需要进动态图到静态图之间的转换,即构建出动态图然后将其映射到静态图,并且将网络参数也进行转换。

二、python trace 机制捕获动态图

pyTroch框架采用python构建,通过使用python的trace机制可以获取到网络在传递过程中所经过的结构,从而映射到静态图。主要步骤如下:

  1. 启动python的trace功能,并定义其回调函数。
  2. 在回调函数中捕获网络所调用的原子操作。
  3. 将对应的操作使用caffe的python接口进行映射。
  4. 将相应pyTorch的网络参数映射到caffe模型。
  5. 保存caffe模型,关闭python的trace功能。

代码框架如下:

import sys
import torch
from caffe import layers as L, params as P, to_proto

def tracea_fun(frame, event, arg):
    //通过当前的frame栈得到每次调用的函数,并将其转换为相应的caffe调用
    
def main(model, input):
    sys.settrace(trace_fun)
    output = model(input)
    sys.settrace(None)

if __name__ == "__main__":
    input = DataLoder()
    model = Net()
    
    main(model, input)

2.1 sys.settrace 操作捕获

python的sys.settrace定义docs.python.org:

Set the system’s trace function, which allows you to implement a Python source code debugger in Python.
Trace functions should have three arguments: frame, event, and arg. frame is the current stack frame. event is a string: ‘call’, ‘line’, ‘return’ or ‘exception’. arg depends on the event type.

因此通过sys.settrace的回调函数中的frame栈可以捕获当前的操作,其中frameframe objects,定义见The standard type hierarchy,常用属性有:

  • f_code: The code object being executed in this frame
    • co_name: Function name
    • co_varnames: A tuple containing the names of the local variables
  • f_locals: The dictionary used to look up local variables
  • f_back: The previous stack frame

则通过frame.f_code.co_namefrmae.f_locals可以获得网络传递过程中的函数名和参数。

2.2 pyTorch原子操作捕获

实现该方法的难点在于如何找到网络中数据的流向,比如进行的view操作和resNet网络中何时进行add操作,这些操作在pyTorch0.2中都封装成了相应的原子操作,只需要找到对应的调用函数即可(但是在pyToch0.3以上中直接调用C的接口,暂时不会怎么使用settrace进行捕捉)。

以卷基层为例,在trace_fun中的conv2d代码如下:

def trace_fun(frame, event, arg):
    if frame.f_code.co_name == "conv2d":
        groups = frame.f_locals["groups"]
        pad_h = frame.f_locals["padding"][0]
        pad_w = frame.f_locals["padding"][1]
        stride_h = frame.f_locals["stride"][0]
        stride_w = frame.f_locals["stride"][1]
        dilation = frame.f_locals["dilation"]
        weight = frame.f_locals["weight"]
        bias = frame.f_locals["bias"]
        bottom = getBottom()
        name = "conv1"
        
        top = L.Convolution(bottom, name=name,
                            kernel_h=kernel_h, kernel_w = kernel_w,
                            num_output=num_output, groups=groups,
                            stride_h=stride_h, stride_w=stride_w,
                            pad_h=pad_h, pad_w=pad_w,
                            dilation=dilation)

其中,getBottom()为获取当前层的前一层,通过维护一个容器,在容器中以每层的物理地址作为该层的唯一索引进行检索,即使用id(feature)来确定其前一层。

注意,该方法只用于第一个pyTorch0.2之前的版本,在0.3之后的版本通过直接调用C接口的方式,目前还不会将其操作栈剥离出来。

三、pyTorch grad_fn网络拓扑图构建

四、pyTorch1.0 ONNX和caffe2之间的使用

你可能感兴趣的:(PyTorch)