PyTorch Custom C++ and CUDA Extensions

MASKRCNN_BENCHMARK

maskrcnn-benchmark is Facebook's open-source implementation of the Mask R-CNN instance segmentation algorithm, written in Python with PyTorch.

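Reading the maskrcnn-benchmark code, I found that it mixes C++, CUDA and Python, with the bindings written using pybind11, and that it also uses apex for mixed-precision training. It really brings together a lot of deep learning engineering in one place.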

The following overview of the maskrcnn-benchmark code structure is excerpted from another blog post; I found it very instructive:

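The code abstracts the model into a single class, GeneralizedRCNN. Images and targets are fed into the model, which returns the losses; all of the loss computation lives in forward. The outer train function only handles the fixed, generic chores: looping over iterations, loss backward, optimizer.step, logging, and so on.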
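The division of labour described above, where the model's forward returns its losses and the outer loop stays generic, can be sketched in plain Python. StubGeneralizedRCNN and train are hypothetical stand-ins written for illustration: the loss names and values are dummies and no network is run. In real PyTorch code the commented line would call backward() and optimizer.step().

```python
from typing import Dict, Iterable, List, Sequence, Tuple


class StubGeneralizedRCNN:
    """Hypothetical stand-in for GeneralizedRCNN (no real network inside).

    Training mode: forward(images, targets) returns a dict of named losses.
    Eval mode: forward(images) returns one detection result per image.
    """

    def __init__(self) -> None:
        self.training = True

    def forward(self, images: Sequence[str], targets=None):
        if self.training:
            # Each stage contributes its own losses; the values are dummies.
            return {
                "loss_objectness": 0.5,    # RPN
                "loss_rpn_box_reg": 0.25,  # RPN
                "loss_classifier": 1.0,    # ROI box head
                "loss_box_reg": 0.25,      # ROI box head
            }
        return [{"boxes": [], "labels": []} for _ in images]


def train(model: StubGeneralizedRCNN,
          data_loader: Iterable[Tuple[Sequence[str], Sequence[str]]],
          max_iter: int) -> List[float]:
    """Generic loop: it never needs to know which losses the model computes."""
    history: List[float] = []
    for iteration, (images, targets) in enumerate(data_loader, start=1):
        if iteration > max_iter:
            break
        loss_dict: Dict[str, float] = model.forward(images, targets)
        total_loss = sum(loss_dict.values())
        # With PyTorch tensors: optimizer.zero_grad(); total_loss.backward(); optimizer.step()
        history.append(total_loss)  # stand-in for logging
    return history


batches = [(["img0.jpg"], ["target0"])] * 3
print(train(StubGeneralizedRCNN(), batches, max_iter=2))
```

Because the loop only sums whatever the dict contains, adding a new head (for example a mask head with its own loss) requires no change to train.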

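GeneralizedRCNN consists of three parts: backbone, RPN and roi_heads. Network construction relies heavily on factory methods, so different detection networks can largely be created from the configuration, and many combinations are supported.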

  • model GeneralizedRCNN
    • backbone nn.Sequential
      • body ResNet
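        • stem StemWithFixedBatchNorm // ResNet stem (the base layers)
        • module BottleneckWithFixedBatchNorm // ResNet bottleneck block
      • fpn
        • inner_block 1×1 conv layers
        • layer_block 3×3 conv layers
      • The backbone is the R-CNN's backbone network, i.e. the image feature extractor. With ResNet + FPN it outputs 256-d feature maps at each of 5 scales, and these feature maps are the input to the next stage, the RPN.
    • rpn RPNModule
      • head RPNHead // the RPN's CNN head: for every feature-map location it predicts cls_logits (objectness) and bbox regression
      • anchor_generator AnchorGenerator // generates the anchor boxes at every feature-map location
      • box_selector_train RPNPostProcessor // decodes the head's predictions against the anchors and keeps the most likely boxes as proposals
      • loss_evaluator RPNLossComputation // computes the RPN loss
      • The RPN is a CNN whose input is the feature maps and whose output is bbox coordinates (region proposals).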
    • roi_heads CombinedROIHeads
      • box ROIBoxHead
        • feature_extractor FPN2MLPFeatureExtractor
          • pooler
          • fc6
          • fc7
        • predictor FPNPredictor
          • cls_score nn.Linear
          • bbox_pred nn.Linear
        • post_processor
        • loss_evaluator
      • mask ROIMaskHead

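Overall, the backbone computes features, the RPN proposes boxes, and roi_heads classify the objects (the box head also refines the boxes and the mask head predicts masks). From the outside the model is nicely encapsulated; the details are hidden inside the individual classes.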
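The factory-style construction mentioned above, building components by name from the configuration, is commonly done with a name-to-constructor registry. The minimal sketch below illustrates the idea; Registry, BACKBONES, the config keys and the "R-50-FPN"/"R-50-C4" names are illustrative assumptions, not code copied from maskrcnn-benchmark, and the builders return strings instead of nn.Module objects so the sketch runs without PyTorch.

```python
from typing import Any, Callable, Dict

Builder = Callable[[Dict[str, Any]], Any]


class Registry(dict):
    """Maps a name used in the config to the function that builds that component."""

    def register(self, name: str) -> Callable[[Builder], Builder]:
        def decorator(builder: Builder) -> Builder:
            if name in self:
                raise KeyError(f"{name!r} is already registered")
            self[name] = builder
            return builder
        return decorator


BACKBONES = Registry()


@BACKBONES.register("R-50-FPN")
def build_resnet_fpn_backbone(cfg: Dict[str, Any]) -> str:
    # A real builder would return nn.Sequential(body, fpn); a string keeps this runnable.
    return f"ResNet-{cfg['depth']} + FPN, {cfg['out_channels']}-d features"


@BACKBONES.register("R-50-C4")
def build_resnet_c4_backbone(cfg: Dict[str, Any]) -> str:
    return f"ResNet-{cfg['depth']} C4"


def build_backbone(cfg: Dict[str, Any]) -> Any:
    """Factory: look up the builder by name, so new backbones need no changes here."""
    return BACKBONES[cfg["backbone"]](cfg)


print(build_backbone({"backbone": "R-50-FPN", "depth": 50, "out_channels": 256}))
```

Adding a backbone then only means registering one more builder; the code that assembles the detector stays untouched.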

PyTorch Custom C++ and CUDA Extensions

Reference: the official PyTorch tutorial on custom C++ and CUDA extensions

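While reading maskrcnn-benchmark/maskrcnn_benchmark/csrc/cpu/nms_cpu.cpp, I noticed that the repository implements operations such as NMS and ROI Align in C++/CUDA and uses python setup.py install to package them as the Python extension maskrcnn_benchmark._C.

My guess as to why these operations are written in C++ is that something like NMS is called very often and typically has to compare tens of thousands of candidate boxes against each other, so a C++/CUDA implementation can be considerably faster than Python. The code is worth reading yourself: libtorch's tensor API closely mirrors PyTorch's Python API, so it should not feel unfamiliar. The C++ code is exposed to Python through pybind11, which is neat.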
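To see why NMS is a natural candidate for C++/CUDA, here is a plain-Python sketch of greedy NMS. It is not a port of nms_cpu.cpp: it assumes continuous coordinates (width = x2 - x1), whereas some implementations add 1 to widths and heights. Each kept box is compared with every lower-scored box, so the worst case is O(N²) IoU computations, which adds up quickly for tens of thousands of candidates in an interpreted loop.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(boxes: Sequence[Box], scores: Sequence[float], thresh: float) -> List[int]:
    """Greedy NMS: return indices of the kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    suppressed = [False] * len(boxes)
    keep: List[int] = []
    for pos, i in enumerate(order):
        if suppressed[i]:
            continue
        keep.append(i)
        # Suppress every lower-scored box that overlaps box i by more than thresh.
        for j in order[pos + 1:]:
            if not suppressed[j] and iou(boxes[i], boxes[j]) > thresh:
                suppressed[j] = True
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores, thresh=0.5))
```

The first two boxes overlap with IoU 81/119 ≈ 0.68, so at a threshold of 0.5 the lower-scored one is suppressed and indices 0 and 2 are kept.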

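The official PyTorch tutorial explains how to write custom C++ and CUDA extensions, using LLTM (Long-Long-Term-Memory) as its example. Its forward and backward passes are implemented as follows: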

  • lltm.cpp
/*
 * lltm.cpp
 */
#include <torch/extension.h>  // at::/torch:: tensor ops and pybind11

#include <vector>

// sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z))
torch::Tensor d_sigmoid(torch::Tensor z) {
    auto s = torch::sigmoid(z);
    return (1 - s) * s;
}

// Forward pass: returns new_h and new_cell plus the intermediates lltm_backward needs.
std::vector<at::Tensor> lltm_forward(
        at::Tensor input,
        at::Tensor weights,
        at::Tensor bias,
        at::Tensor old_h,
        at::Tensor old_cell) {
    // X = [old_h, input], concatenated along the feature dimension
    auto X = at::cat({old_h, input}, /*dim=*/1);

    // gate_weights = bias + X * weights^T, split into the three gates
    auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1));
    auto gates = gate_weights.chunk(3, /*dim=*/1);

    auto input_gate = at::sigmoid(gates[0]);
    auto output_gate = at::sigmoid(gates[1]);
    auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0);

    auto new_cell = old_cell + candidate_cell * input_gate;
    auto new_h = at::tanh(new_cell) * output_gate;

    return {new_h,
            new_cell,
            input_gate,
            output_gate,
            candidate_cell,
            X,
            gate_weights};
}

// tanh'(z) = 1 - tanh^2(z)
at::Tensor d_tanh(at::Tensor z) {
    return 1 - z.tanh().pow(2);
}

// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
    auto e = z.exp();
    auto mask = (alpha * (e - 1)) < 0;
    return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}

// Backward pass: chain rule through the forward ops; returns the gradients
// w.r.t. old_h, input, weights, bias and old_cell.
std::vector<at::Tensor> lltm_backward(
        at::Tensor grad_h,
        at::Tensor grad_cell,
        at::Tensor new_cell,
        at::Tensor input_gate,
        at::Tensor output_gate,
        at::Tensor candidate_cell,
        at::Tensor X,
        at::Tensor gate_weights,
        at::Tensor weights) {
    auto d_output_gate = at::tanh(new_cell) * grad_h;
    auto d_tanh_new_cell = output_gate * grad_h;
    auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

    auto d_old_cell = d_new_cell;
    auto d_candidate_cell = input_gate * d_new_cell;
    auto d_input_gate = candidate_cell * d_new_cell;

    // Back through the gate activations
    auto gates = gate_weights.chunk(3, /*dim=*/1);
    d_input_gate *= d_sigmoid(gates[0]);
    d_output_gate *= d_sigmoid(gates[1]);
    d_candidate_cell *= d_elu(gates[2]);

    auto d_gates =
            at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

    // Back through gate_weights = bias + X * weights^T
    auto d_weights = d_gates.t().mm(X);
    auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);

    // Split d_X back into the old_h part and the input part
    auto d_X = d_gates.mm(weights);
    const auto state_size = grad_h.size(1);
    auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
    auto d_input = d_X.slice(/*dim=*/1, state_size);

    return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}

// Expose both functions to Python; the build defines TORCH_EXTENSION_NAME (here: lltm)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &lltm_forward, "LLTM forward");
    m.def("backward", &lltm_backward, "LLTM backward");
}
  • setup.py
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension

setup(name='lltm',
      ext_modules=[CppExtension('lltm', ['lltm.cpp'])],
      cmdclass={'build_ext': BuildExtension})
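For reference, the computation in lltm.cpp can be written out as equations. The notation is mine: $x$ = input, $h_{t-1}$ = old_h, $c_{t-1}$ = old_cell, $W$ = weights, $b$ = bias, $\odot$ is the elementwise product, $[\cdot,\cdot]$ is concatenation along dim 1, and $L$ is the loss. The last two lines list the derivatives that lltm_backward relies on (d_sigmoid, d_tanh, d_elu, and the gradients of the affine map $G$).

```latex
\begin{aligned}
X &= [\,h_{t-1},\ x\,], \qquad G = X W^{\top} + b = [\,G_i,\ G_o,\ G_c\,] \\
i &= \sigma(G_i), \qquad o = \sigma(G_o), \qquad \tilde{c} = \operatorname{ELU}_{\alpha}(G_c) \\
c_t &= c_{t-1} + \tilde{c} \odot i, \qquad h_t = \tanh(c_t) \odot o \\[4pt]
\sigma'(z) &= \sigma(z)\,\bigl(1 - \sigma(z)\bigr), \qquad
\tanh'(z) = 1 - \tanh^{2}(z), \qquad
\operatorname{ELU}_{\alpha}'(z) = \begin{cases} 1, & z > 0 \\ \alpha e^{z}, & z < 0 \end{cases} \\
\frac{\partial L}{\partial W} &= \Bigl(\frac{\partial L}{\partial G}\Bigr)^{\top} X, \qquad
\frac{\partial L}{\partial b} = \sum_{\text{rows}} \frac{\partial L}{\partial G}, \qquad
\frac{\partial L}{\partial X} = \frac{\partial L}{\partial G}\, W
\end{aligned}
```

These correspond line by line to d_weights = d_gates.t().mm(X), d_bias = d_gates.sum(0, true) and d_X = d_gates.mm(weights) in the C++ code.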

Run python setup.py install to build and install the extension. If it completes without errors, the module is ready to use:

$ ipython
Python 3.7.3 (default, Mar 27 2019, 16:54:48)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.8.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import torch

In [2]: ls
build/             dist/              lltm.cpp           lltm.egg-info/     lltm_cpp.egg-info/ setup.py

In [3]: import lltm

In [4]: dir(lltm)
Out[4]: 
['__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 'backward',
 'forward']
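Since lltm_backward is written by hand rather than derived by autograd, its derivative helpers deserve a numerical sanity check. The sketch below is my own scalar port of d_sigmoid, d_tanh and d_elu from lltm.cpp (using Python's math module instead of tensors) and compares each against a central finite difference; the step size, sample points and tolerance are arbitrary choices.

```python
import math
from typing import Callable, Iterable


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def elu(z: float, alpha: float = 1.0) -> float:
    return z if z > 0 else alpha * (math.exp(z) - 1.0)


# Scalar ports of the analytic derivatives in lltm.cpp.
def d_sigmoid(z: float) -> float:
    s = sigmoid(z)
    return (1.0 - s) * s


def d_tanh(z: float) -> float:
    return 1.0 - math.tanh(z) ** 2


def d_elu(z: float, alpha: float = 1.0) -> float:
    e = math.exp(z)
    mask = 1.0 if alpha * (e - 1.0) < 0 else 0.0
    return (1.0 if z > 0 else 0.0) + mask * alpha * e


def central_diff(f: Callable[[float], float], z: float, h: float = 1e-6) -> float:
    """Numerical derivative (f(z + h) - f(z - h)) / (2h)."""
    return (f(z + h) - f(z - h)) / (2.0 * h)


def max_abs_error(f: Callable[[float], float],
                  df: Callable[[float], float],
                  points: Iterable[float]) -> float:
    return max(abs(df(z) - central_diff(f, z)) for z in points)


# Avoid z = 0: ELU has a kink there when alpha != 1, and the ported d_elu returns 0.
points = [-3.0, -1.0, -0.25, 0.25, 1.0, 3.0]
for name, f, df in [("sigmoid", sigmoid, d_sigmoid),
                    ("tanh", math.tanh, d_tanh),
                    ("elu", elu, d_elu)]:
    print(name, max_abs_error(f, df, points) < 1e-6)
```

With real tensors, torch.autograd.gradcheck performs the same kind of comparison on an autograd.Function that wraps lltm.forward and lltm.backward; it expects double-precision inputs with requires_grad=True.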

Troubleshooting

  1. Import error: Symbol not found
$ ipython
Python 3.7.3 (default, Mar 27 2019, 16:54:48)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.8.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import lltm
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-1-e0ee809aa11b> in <module>
----> 1 import lltm

ImportError: dlopen(/anaconda3/lib/python3.7/site-packages/lltm-0.0.0-py3.7-macosx-10.7-x86_64.egg/lltm.cpython-37m-darwin.so, 2): Symbol not found: _THPVariableClass
  Referenced from: /anaconda3/lib/python3.7/site-packages/lltm-0.0.0-py3.7-macosx-10.7-x86_64.egg/lltm.cpython-37m-darwin.so
  Expected in: flat namespace
 in /anaconda3/lib/python3.7/site-packages/lltm-0.0.0-py3.7-macosx-10.7-x86_64.egg/lltm.cpython-37m-darwin.so

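The fix comes from a GitHub issue: import torch before importing your own extension, i.e. run import torch first and only then import lltm. The missing symbol _THPVariableClass is defined in PyTorch's own Python-binding library, and the extension expects it to be loaded already ("Expected in: flat namespace"), which only happens once torch has been imported.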
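If the extension is imported from several places, the required order can be made explicit in one spot. import_extension below is a hypothetical helper (not part of PyTorch) that imports its prerequisite modules, by default just torch, before the compiled extension, so the symbols that torch's libraries provide should already be loaded when the extension's shared library is opened.

```python
import importlib
from types import ModuleType
from typing import Sequence


def import_extension(name: str, prerequisites: Sequence[str] = ("torch",)) -> ModuleType:
    """Import every module in `prerequisites` first, then the extension `name`.

    For a torch C++ extension this is equivalent to writing `import torch`
    before `import lltm`: loading torch first brings its symbols
    (e.g. THPVariableClass) into the process before the extension's .so is loaded.
    """
    for module_name in prerequisites:
        importlib.import_module(module_name)
    return importlib.import_module(name)


# Typical use, once the extension has been built and installed:
#   lltm = import_extension("lltm")
```

Repeated imports are cheap because importlib returns the cached module from sys.modules after the first load.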

  2. Compilation errors
torch/lib/include/c10/Device.h:109:8: error:
      redefinition of 'hash'

torch/lib/include/c10/util/UniqueVoidPtr.h:108:54: error:
      no type named 'nullptr_t' in namespace 'std'

According to another GitHub issue, the fix is to raise the macOS deployment target before building:

export MACOSX_DEPLOYMENT_TARGET=10.11
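Instead of exporting the variable in every shell, the same workaround can live in setup.py. ensure_macos_deployment_target below is a hypothetical helper (my own name) meant to be called before setup(...). As for why the workaround helps: the egg name in the import error above (macosx-10.7) suggests this Anaconda Python was built for a 10.7 deployment target, which distutils reuses for extensions; with a target below 10.9, Apple's clang most likely defaults to the old libstdc++, which lacks C++11 pieces such as std::nullptr_t, whereas 10.9 and later use libc++.

```python
import os
import sys
from typing import MutableMapping, Optional


def ensure_macos_deployment_target(
    environ: MutableMapping[str, str] = os.environ,
    platform: str = sys.platform,
    target: str = "10.11",
) -> Optional[str]:
    """Set MACOSX_DEPLOYMENT_TARGET on macOS unless it is already set.

    Returns the value now in effect, or None on other platforms.
    """
    if platform != "darwin":
        return None
    # setdefault keeps a value the user exported explicitly (e.g. 10.14).
    return environ.setdefault("MACOSX_DEPLOYMENT_TARGET", target)
```

In setup.py this would be called at the top, before setup(...), so that the compiler processes started during the build inherit the variable; an explicitly exported value still takes precedence.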
