deformable_conv可变形卷积源码解析

注意:本文源码来源于 OpenMMLab 的 mmcv 库

变形卷积源码主要有三个文件:

  • deform_conv.cpp: 位于/mmcv/ops/csrc/pytorch/deform_conv.cpp
  • deform_conv_cuda.cu:位于/mmcv/ops/csrc/pytorch/deform_conv_cuda.cu
  • deform_conv_cuda_kernel.cuh:位于/mmcv/ops/csrc/deform_conv_cuda_kernel.cuh

前向传播

首先查看deform_conv.cpp中的函数deform_conv_forward()

函数输入主要有输入特征图input、卷积权重weight、坐标偏移offset以及输出特征图output等

函数调用了deform_conv_forward_cuda()函数

// Entry point for the deformable-convolution forward pass.
//
// Only CUDA tensors are supported. A CPU `input` raises immediately. A build
// without MMCV_WITH_CUDA raises for CUDA tensors as well. Otherwise every
// tensor argument is validated with CHECK_CUDA_INPUT (in the order below),
// and the call is forwarded unchanged to deform_conv_forward_cuda().
void deform_conv_forward(Tensor input, Tensor weight, Tensor offset,
                         Tensor output, Tensor columns, Tensor ones, int kW,
                         int kH, int dW, int dH, int padW, int padH,
                         int dilationW, int dilationH, int group,
                         int deformable_group, int im2col_step) {
  // Guard clause: there is no CPU implementation. AT_ERROR throws.
  if (!input.device().is_cuda()) {
    AT_ERROR("DeformConv is not implemented on CPU");
  }

#ifdef MMCV_WITH_CUDA
  CHECK_CUDA_INPUT(input);
  CHECK_CUDA_INPUT(offset);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(output);
  CHECK_CUDA_INPUT(columns);
  CHECK_CUDA_INPUT(ones);

  deform_conv_forward_cuda(input, weight, offset, output, columns, ones,
                           /*kernel*/ kW, kH, /*stride*/ dW, dH,
                           /*padding*/ padW, padH,
                           /*dilation*/ dilationW, dilationH,
                           group, deformable_group, im2col_step);
#else
  AT_ERROR("DeformConv is not compiled with GPU support");
#endif
}

进一步调用了函数DeformConvForwardCUDAKernelLauncher(),位于文件deform_conv_cuda.cu中

// CUDA-side entry point for the deformable-convolution forward pass.
// It does no work of its own. Every argument is passed through unchanged,
// in the same order, to DeformConvForwardCUDAKernelLauncher().
void deform_conv_forward_cuda(Tensor input, Tensor weight, Tensor offset,
                              Tensor output, Tensor columns, Tensor ones,
                              int kW, int kH, int dW, int dH, int padW,
                              int padH, int dilationW, int dilationH, int group,
                              int deformable_group, int im2col_step) {
  DeformConvForwardCUDAKernelLauncher(
      /* tensors     */ input, weight, offset, output, columns, ones,
      /* kernel size */ kW, kH,
      /* stride      */ dW, dH,
      /* padding     */ padW, padH,
      /* dilation    */ dilationW, dilationH,
      /* grouping    */ group, deformable_group,
      /* batching    */ im2col_step);
}

这是变形卷积的核心函数。常规卷积中,特征图上每个点与其固定的规则邻域点进行卷积;而变形卷积额外引入了坐标偏移offset,先按偏移后的位置(不再是规则相邻的点)重新采样邻域点特征,再进行卷积。所以变形卷积主要分为两个步骤:

1)根据offset收集新的邻域特征;

2)再进行卷积。

其中第一步收集邻域特征相对比较麻烦。主要思路是:对每个输出位置,根据offset找到kh*kw个偏移后的采样点,用双线性插值取出它们的特征并堆叠在一起,使得通道数由C变成C*kh*kw。这一步由下文代码中的deformable_im2col()函数实现。

第二步就很简单了:相当于用一个权重维度为(C2, C*kh*kw, 1, 1)的1×1卷积(即一次矩阵乘法)就可以实现;分组卷积时,每组的权重矩阵为(C2/group, C/group*kh*kw)。

// Host-side driver for the deformable-convolution forward pass.
//
// The computation is "deformable im2col + GEMM", done in chunks of
// `im2col_step` images:
//   1. deformable_im2col() fills `columns` (C1*kH*kW, step*h*w). For every
//      output location of the chunk, it holds the C1 channels of the kH*kW
//      input points sampled at the offset-displaced positions.
//   2. For each convolution group g, the grouped weight
//      (C2/group, C1/group*kH*kW) is multiplied with the matching row block
//      of `columns` and accumulated into the output.
//
// Shapes, as established by the views below:
//   input   : (B, C1, H, W), or (C1, H, W), which is treated as B = 1
//   weight  : (C2, C1/group, kH, kW)
//   offset  : (B, deformable_group*2*kH*kW, h, w)
//   output  : receives the result via copy_; must hold B*C2*h*w elements
//   columns : scratch; reallocated here, so the caller's buffer is not used
//   ones    : not used by this implementation
// kW/kH: kernel size; dW/dH: stride; padW/padH: padding;
// dilationW/dilationH: dilation; group: convolution groups;
// deformable_group: number of channel groups, each with its own offset field;
// im2col_step: images per chunk; must be positive and divide B.
void DeformConvForwardCUDAKernelLauncher(Tensor input, Tensor weight,
                                         Tensor offset, Tensor output,
                                         Tensor columns, Tensor ones, int kW,
                                         int kH, int dW, int dH, int padW,
                                         int padH, int dilationW, int dilationH,
                                         int group, int deformable_group,
                                         int im2col_step) {
  deform_conv_shape_check(input, offset, nullptr, weight, kH, kW, dH, dW, padH,
                          padW, dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  if (input.ndimension() == 3) {
    // Force a batch dimension. Out-of-place unsqueeze keeps the caller's
    // tensors unchanged; unsqueeze_ would silently reshape them to 4-D.
    input = input.unsqueeze(0);
    offset = offset.unsqueeze(0);
  }

  long batchSize = input.size(0);     // B
  long nInputPlane = input.size(1);   // C1
  long inputHeight = input.size(2);   // H
  long inputWidth = input.size(3);    // W

  long nOutputPlane = weight.size(0);  // C2

  // Standard convolution output size.
  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;   // w
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;  // h

  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
  // The chunked views below require B to be a multiple of im2col_step;
  // report that clearly instead of failing inside view().
  TORCH_CHECK(im2col_step > 0 && batchSize % im2col_step == 0,
              "batch size must be divisible by im2col_step");

  // (B, C2, h, w) -> (B/step, step, C2, h, w)
  output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
                        outputHeight, outputWidth});
  // (C1*kW*kH, step*h*w)
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  // (B, C1, H, W) -> (B/step, step, C1, H, W)
  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  // (B, dg*2*kH*kW, h, w) -> (B/step, step, dg*2*kH*kW, h, w)
  offset = offset.view({batchSize / im2col_step, im2col_step,
                        deformable_group * 2 * kH * kW, outputHeight,
                        outputWidth});

  // Zero-initialised accumulator (B/step, C2, step*h, w), viewed as
  // (B/step, group, C2/group, step*h, w) so that each group's slice flattens
  // to the (C2/group, step*h*w) GEMM result.
  Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane,
                                    im2col_step * outputHeight, outputWidth},
                                   output.options());
  output_buffer = output_buffer.view(
      {output_buffer.size(0), group, output_buffer.size(1) / group,
       output_buffer.size(2), output_buffer.size(3)});

  // Loop-invariant grouped weight:
  // (C2, C1/group, kH, kW) -> (group, C2/group, C1/group, kH, kW).
  // It is built once. Re-viewing `weight` inside the loop without restoring
  // it would feed the wrong sizes to view() from the second chunk on.
  Tensor grouped_weight =
      weight.view({group, weight.size(0) / group, weight.size(1),
                   weight.size(2), weight.size(3)});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    // Gather the deformed neighbourhoods of chunk `elt` into `columns`.
    // input[elt]: (step, C1, H, W); offset[elt]: (step, dg*2*kH*kW, h, w).
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    // (C1*kH*kW, step*h*w) -> (group, C1/group*kH*kW, step*h*w)
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});

    for (int g = 0; g < group; g++) {
      // out_g (C2/group, step*h*w) += W_g (C2/group, C1/group*kH*kW)
      //                               @ cols_g (C1/group*kH*kW, step*h*w)
      output_buffer[elt][g] = output_buffer[elt][g]
                                  .flatten(1)
                                  .addmm_(grouped_weight[g].flatten(1),
                                          columns[g])
                                  .view_as(output_buffer[elt][g]);
    }

    // Back to 2-D for the next deformable_im2col() call.
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
  }

  // (B/step, group, C2/group, step*h, w) -> (B/step, C2, step*h, w)
  output_buffer = output_buffer.view(
      {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
       output_buffer.size(3), output_buffer.size(4)});
  // -> (B/step, C2, step, h, w) -> (B/step, step, C2, h, w)
  output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
                                      im2col_step, outputHeight, outputWidth});
  output_buffer.transpose_(1, 2);
  // Write the result into the caller's output storage.
  output.copy_(output_buffer);
  // Tensors are passed by handle. Reshaping the local handles after copy_
  // would not affect the caller, so no shape restoration is done here.
}

上文代码中的函数deformable_im2col(),变形卷积前向传播重点代码,实现邻域特征的汇总

// Launches deformable_im2col_gpu_kernel on the current CUDA stream.
//
// data_im     : (parallel_imgs, channels, height, width) input chunk
// data_offset : (parallel_imgs, deformable_group*2*ksize_h*ksize_w,
//                height_col, width_col) sampling offsets
// data_col    : (channels*ksize_h*ksize_w, parallel_imgs*height_col*width_col)
//               output buffer, filled by the kernel
// height_col/width_col are the standard convolution output sizes. The kernel
// is given channels*height_col*width_col*parallel_imgs work items.
void deformable_im2col(Tensor data_im, Tensor data_offset, const int channels,
                       const int height, const int width, const int ksize_h,
                       const int ksize_w, const int pad_h, const int pad_w,
                       const int stride_h, const int stride_w,
                       const int dilation_h, const int dilation_w,
                       const int parallel_imgs, const int deformable_group,
                       Tensor data_col) {
  // num_axes should be smaller than block size
  int height_col =
      (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col =
      (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  int num_kernels = channels * height_col * width_col * parallel_imgs;
  // Number of input channels that share one offset field.
  int channel_per_deformable_group = channels / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
        // data_ptr needs the element type; plain data_ptr() yields void*.
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();

        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels),
                                       THREADS_PER_BLOCK, 0,
                                       at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_im_, data_offset_, height, width, ksize_h,
            ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
            channel_per_deformable_group, parallel_imgs, channels,
            deformable_group, height_col, width_col, data_col_);
      }));
  AT_CUDA_CHECK(cudaGetLastError());
}

函数deformable_im2col_gpu_kernel()在deform_conv_cuda_kernel.cuh文件中

template 
__global__ void deformable_im2col_gpu_kernel(
    const int n, const T *data_im, const T *data_offset, const int height,  // n=C1*h*w*step,;data_im:(step, C1, H, W)的起始地址,为输入特征,;data_offset:(step, deform_group*2*kw*kh, h, w)的起始地址,为坐标偏置; height=H
    const int width, const int kernel_h, const int kernel_w, const int pad_h, //输入特征宽 width=W 
    const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group, const int batch_size, //  batch_size=step
    const int num_channels, const int deformable_group, const int height_col, //num_channels=C1, 输出特征的高height_col=h
    const int width_col, T *data_col) { //width_col=w,  data_col:(C1*kw*kh, step*h*w)的起始地址,为columns
  CUDA_1D_KERNEL_LOOP(index, n) {  // index从0到n进行遍历
    // index index of output matrix
      // index为0到C1*step*h*w遍历,可以理解为columns上特征点的索引,但比columns维度少了kh*kw,所以一个index对应一个kh*kw
    const int w_col = index % width_col;  // 特征点的w坐标
    const int h_col = (index / width_col) % height_col;   // 特征点的h坐标
    const int b_col = (index / width_col / height_col) % batch_size;  // 特征点所在的batch数
    const int c_im = (index / width_col / height_col) / batch_size; // 特征点对应输入图imput上的通道数,输出特征columns和输入特征data_im上的特征点具有对应关系
    const int c_col = c_im * kernel_h * kernel_w;  // 特征点对应输出特征图columns的通道数,因为columns的通道为C1*kw*kh,而data_im的通道为C1

    // compute deformable group index
    const int deformable_group_index = c_im / channel_per_deformable_group;

    const int h_in = h_col * stride_h - pad_h;    // 输入特征图上的特征点坐标映射在输入特征图上的高度值h
    const int w_in = w_col * stride_w - pad_w; // 输入特征图上的特征点坐标映射在输入特征图上的宽度值w
      // 获得该特征点在columns上的地址
    T *data_col_ptr =  
        data_col +
        ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
      //  获得该特征点在data_im上的对应通道的特征面上的起始地址,因为只考虑了b_col和c_im,还没有包括特征点具体的高和宽
    const T *data_im_ptr = 
        data_im + (b_col * num_channels + c_im) * height * width;
      // 只考虑了b_col和deformable_group_index,还没有包括特征点对应的坐标和偏置
    const T *data_offset_ptr =
        data_offset + (b_col * deformable_group + deformable_group_index) * 2 *
                          kernel_h * kernel_w * height_col * width_col;
// 因为一个index对应一个kh*kw,所以要进一步遍历kh*kw
    for (int i = 0; i < kernel_h; ++i) {
      for (int j = 0; j < kernel_w; ++j) {
        const int data_offset_h_ptr =
            ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;  // 得到完整的偏置高度地址
        const int data_offset_w_ptr =
            ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; // 得到完整的偏置宽度地址
        const T offset_h = data_offset_ptr[data_offset_h_ptr]; // 偏置高度
        const T offset_w = data_offset_ptr[data_offset_w_ptr]; // 偏置宽度
        T val = static_cast(0);
        const T h_im = h_in + i * dilation_h + offset_h;    //对应的输入特征图上该点邻域点的偏置高度值h
        const T w_im = w_in + j * dilation_w + offset_w; //对应的输入特征图上该点邻域点的偏置宽度值w
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
           // 因为h_im和w_im为小数, 所以需要插值计算,后文分析
          val = deformable_im2col_bilinear(data_im_ptr, width, height, width,
                                           h_im, w_im);
        *data_col_ptr = val;
        data_col_ptr += batch_size * height_col * width_col;  // 因为columns的维度为(C1*kw*kh, step*h*w),kw和kh在step*h*w前面,所以每个偏移点中间相隔step*h*w个点
      }
    }
  }
}

函数deformable_im2col_bilinear()根据偏移后的小数坐标,用双线性插值取出该位置的特征值,比较简单

template 
__device__ T deformable_im2col_bilinear(const T *input, const int data_width,  // input:输入图当前通道特征面的起始地址
                                        const int height, const int width, T h,
                                        T w) {
  if (h <= -1 || height <= h || w <= -1 || width <= w) {
    return 0;
  }

    // 上下取整
  int h_low = floor(h);
  int w_low = floor(w);
  int h_high = h_low + 1;
  int w_high = w_low + 1;

  T lh = h - h_low;
  T lw = w - w_low;
  T hh = 1 - lh, hw = 1 - lw;

  T v1 = 0;
  if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low];
  T v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
    v2 = input[h_low * data_width + w_high];
  T v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
    v3 = input[h_high * data_width + w_low];
  T v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
    v4 = input[h_high * data_width + w_high];

  T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}

到此前向传播分析结束,下文分析梯度反传

梯度反传

需要分别计算输入特征input、坐标偏移offset以及卷积权重weight的梯度

输入特征input和坐标偏移offset的梯度计算

梯度反传以输出特征的梯度gradOutput(由后续层反传而来)为起点,反向计算各输入的梯度

查看deform_conv.cpp中的函数deform_conv_backward_input()

该函数调用了deform_conv_backward_input_cuda()函数

// Entry point for the input/offset gradients of deformable convolution.
//
// Only CUDA tensors are supported. A CPU `input` raises immediately. A build
// without MMCV_WITH_CUDA raises for CUDA tensors as well. Otherwise every
// tensor argument is validated with CHECK_CUDA_INPUT (in the order below),
// and the call is forwarded unchanged to deform_conv_backward_input_cuda().
void deform_conv_backward_input(Tensor input, Tensor offset, Tensor gradOutput,
                                Tensor gradInput, Tensor gradOffset,
                                Tensor weight, Tensor columns, int kW, int kH,
                                int dW, int dH, int padW, int padH,
                                int dilationW, int dilationH, int group,
                                int deformable_group, int im2col_step) {
  // Guard clause: there is no CPU implementation. AT_ERROR throws.
  if (!input.device().is_cuda()) {
    AT_ERROR("DeformConv is not implemented on CPU");
  }

#ifdef MMCV_WITH_CUDA
  CHECK_CUDA_INPUT(input);
  CHECK_CUDA_INPUT(offset);
  CHECK_CUDA_INPUT(gradOutput);
  CHECK_CUDA_INPUT(gradInput);
  CHECK_CUDA_INPUT(gradOffset);
  CHECK_CUDA_INPUT(weight);
  CHECK_CUDA_INPUT(columns);

  deform_conv_backward_input_cuda(
      input, offset, gradOutput, gradInput, gradOffset, weight, columns,
      /*kernel*/ kW, kH, /*stride*/ dW, dH, /*padding*/ padW, padH,
      /*dilation*/ dilationW, dilationH, group, deformable_group, im2col_step);
#else
  AT_ERROR("DeformConv is not compiled with GPU support");
#endif
}

函数deform_conv_backward_input_cuda()

进一步调用了函数DeformConvBackwardInputCUDAKernelLauncher()

// CUDA-side entry point for the input/offset gradients of deformable
// convolution. It does no work of its own. Every argument is passed through
// unchanged, in the same order, to DeformConvBackwardInputCUDAKernelLauncher().
void deform_conv_backward_input_cuda(Tensor input, Tensor offset,
                                     Tensor gradOutput, Tensor gradInput,
                                     Tensor gradOffset, Tensor weight,
                                     Tensor columns, int kW, int kH, int dW,
                                     int dH, int padW, int padH, int dilationW,
                                     int dilationH, int group,
                                     int deformable_group, int im2col_step) {
  DeformConvBackwardInputCUDAKernelLauncher(
      /* forward inputs */ input, offset,
      /* gradients      */ gradOutput, gradInput, gradOffset,
      /* weight/scratch */ weight, columns,
      /* kernel size    */ kW, kH,
      /* stride         */ dW, dH,
      /* padding        */ padW, padH,
      /* dilation       */ dilationW, dilationH,
      /* grouping       */ group, deformable_group,
      /* batching       */ im2col_step);
}

函数DeformConvBackwardInputCUDAKernelLauncher()位于文件deform_conv_cuda.cu中,这是核心代码

// Host-side driver for the input and offset gradients of deformable conv.
//
// For each chunk of `im2col_step` images:
//   1. grad_columns = W_g^T @ gradOutput_g for every convolution group g,
//      written into `columns` (C1*kH*kW, step*h*w).
//   2. deformable_col2im_coord() turns grad_columns into gradOffset[elt].
//   3. deformable_col2im() turns grad_columns into gradInput[elt].
//
// Shapes, as established by the views below:
//   input      : (B, C1, H, W), or (C1, H, W), which is treated as B = 1
//   offset     : (B, deformable_group*2*kH*kW, h, w)
//   gradOutput : (B, C2, h, w)
//   gradInput  : B*C1*H*W elements; written by deformable_col2im
//   gradOffset : B*deformable_group*2*kH*kW*h*w elements; written by
//                deformable_col2im_coord
//   weight     : (C2, C1/group, kH, kW)
//   columns    : scratch; reallocated here
// im2col_step must be positive and divide B.
void DeformConvBackwardInputCUDAKernelLauncher(
    Tensor input, Tensor offset, Tensor gradOutput, Tensor gradInput,
    Tensor gradOffset, Tensor weight, Tensor columns, int kW, int kH, int dW,
    int dH, int padW, int padH, int dilationW, int dilationH, int group,
    int deformable_group, int im2col_step) {
  deform_conv_shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW,
                          padH, padW, dilationH, dilationW, group,
                          deformable_group);
  at::DeviceGuard guard(input.device());

  if (input.ndimension() == 3) {
    // Force a batch dimension of 1. view() returns new handles, so the
    // caller's tensors keep their shape.
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
    gradOutput = gradOutput.view(
        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);     // B
  long nInputPlane = input.size(1);   // C1
  long inputHeight = input.size(2);   // H
  long inputWidth = input.size(3);    // W

  long nOutputPlane = weight.size(0);  // C2

  // Standard convolution output size.
  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;   // w
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;  // h

  // TORCH_CHECK takes (cond, message...). The stray `3` left over from
  // THArgCheck was being prepended to the error message.
  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
  TORCH_CHECK(im2col_step > 0 && batchSize % im2col_step == 0,
              "batch size must be divisible by im2col_step");
  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
  // (C1*kW*kH, step*h*w)
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  // change order of grad output:
  // (B, C2, h, w) -> (B/step, step, C2, h, w) -> (B/step, C2, step, h, w)
  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
                                nOutputPlane, outputHeight, outputWidth});
  gradOutput.transpose_(1, 2);

  // (B, C1, H, W) -> (B/step, step, C1, H, W)
  gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
                              inputHeight, inputWidth});
  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  // (B, dg*2*kH*kW, h, w) -> (B/step, step, dg*2*kH*kW, h, w)
  gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
                                deformable_group * 2 * kH * kW, outputHeight,
                                outputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  // Loop-invariant grouped views, built once. Re-viewing `weight` inside the
  // loop without restoring it would feed the wrong sizes to view() from the
  // second chunk on.
  // (C2, C1/group, kH, kW) -> (group, C2/group, C1/group, kH, kW)
  Tensor grouped_weight =
      weight.view({group, weight.size(0) / group, weight.size(1),
                   weight.size(2), weight.size(3)});
  // (B/step, C2, step, h, w) -> (B/step, group, C2/group, step, h, w)
  Tensor grouped_grad_output = gradOutput.view(
      {gradOutput.size(0), group, gradOutput.size(1) / group,
       gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    // divide into groups:
    // (C1*kH*kW, step*h*w) -> (group, C1/group*kH*kW, step*h*w)
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});

    for (int g = 0; g < group; g++) {
      // cols_g = W_g^T (C1/group*kH*kW, C2/group)
      //          @ gradOut_g (C2/group, step*h*w).
      // beta = 0 discards the previous contents of cols_g.
      columns[g] = columns[g].addmm_(
          grouped_weight[g].flatten(1).transpose(0, 1),
          grouped_grad_output[elt][g].flatten(1), 0.0f, 1.0f);
    }

    // Back to (C1*kH*kW, step*h*w) for the col2im kernels.
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});

    // Gradient w.r.t. the sampling offsets of chunk `elt`.
    deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
                            inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
                            dilationH, dilationW, im2col_step, deformable_group,
                            gradOffset[elt]);

    // Gradient w.r.t. the input features of chunk `elt`.
    deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, gradInput[elt]);
  }
  // Tensors are passed by handle, so reshaping the local handles here would
  // not affect the caller. The former "restore shapes" block was dead code.
  // Its batch == 0 branch also indexed offset.size(3) after offset had
  // already been made 3-D, which is out of range.
}

函数 deformable_col2im_coord() 计算坐标偏移的梯度gradOffset,它调用了内核函数 deformable_col2im_coord_gpu_kernel()

// Launches deformable_col2im_coord_gpu_kernel on the current CUDA stream to
// compute the gradient of the sampling offsets.
//
// data_col    : (channels*ksize_h*ksize_w, parallel_imgs*height_col*width_col)
//               gradient w.r.t. the im2col columns
// data_im     : (parallel_imgs, channels, height, width) forward input chunk
// data_offset : (parallel_imgs, deformable_group*2*ksize_h*ksize_w,
//                height_col, width_col) forward offsets
// grad_offset : same shape as data_offset; one element per kernel work item
// num_kernels = height_col*width_col*2*ksize_h*ksize_w*deformable_group*
// parallel_imgs, i.e. one work item per grad_offset element.
void deformable_col2im_coord(
    Tensor data_col, Tensor data_im, Tensor data_offset, const int channels,
    const int height, const int width, const int ksize_h, const int ksize_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int parallel_imgs,
    const int deformable_group, Tensor grad_offset) {
  int height_col =
      (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col =
      (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w *
                    deformable_group * parallel_imgs;
  // Number of data_col rows (channel x kernel tap) per deformable group.
  int channel_per_deformable_group =
      channels * ksize_h * ksize_w / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
        // data_ptr needs the element type; plain data_ptr() yields void*.
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();

        deformable_col2im_coord_gpu_kernel<<<
            GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0,
            at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_im_, data_offset_, channels, height,
            width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs,
            2 * ksize_h * ksize_w * deformable_group, deformable_group,
            height_col, width_col, grad_offset_);
      }));
  AT_CUDA_CHECK(cudaGetLastError());
}

deformable_col2im_coord_gpu_kernel()内核函数

template 
__global__ void deformable_col2im_coord_gpu_kernel(
    const int n,  // h*w*2*kh*kw*deform_group*step
    const T *data_col,  // (C1*kh*kw, step*h*w)
    const T *data_im, // (step, C1, H, W)
    const T *data_offset,  //  (step, deform_group*2*kw*kh, h, w), 
    const int channels, const int height, const int width, const int kernel_h, //channels: C1
    const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
    const int stride_w, const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group, const int batch_size,  // batch_size=step
    const int offset_channels,  //  deform_group*2*kw*kh
    const int deformable_group, const int height_col,  
    const int width_col, T *grad_offset) {   //grad_offset: (step, deform_group*2*kw*kh, h, w)
    // index从0到h*w*2*kh*kw*deform_group*step进行遍历,也就是遍历整个offset
  CUDA_1D_KERNEL_LOOP(index, n) {
    T val = 0;
    int w = index % width_col;
    int h = (index / width_col) % height_col;
    int c = (index / width_col / height_col) % offset_channels;
    int b = (index / width_col / height_col) / offset_channels;
    // compute the start and end of the output

    const int deformable_group_index = c / (2 * kernel_h * kernel_w);  //  c/(2*kh*kw) 表示第几个group
    const int col_step = kernel_h * kernel_w;
    int cnt = 0;
    const T *data_col_ptr = data_col + deformable_group_index *
                                           channel_per_deformable_group *
                                           batch_size * width_col * height_col;
    const T *data_im_ptr =
        data_im + (b * deformable_group + deformable_group_index) *
                      channel_per_deformable_group / kernel_h / kernel_w *
                      height * width;
    const T *data_offset_ptr =
        data_offset + (b * deformable_group + deformable_group_index) * 2 *
                          kernel_h * kernel_w * height_col * width_col;

    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;  // c

    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
         col_c += col_step) {
      const int col_pos =
          (((col_c * batch_size + b) * height_col) + h) * width_col + w;
      const int bp_dir = offset_c % 2;

      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
      int i =
          (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      int w_out = col_pos % width_col;
      int h_out = (col_pos / width_col) % height_col;
      int w_in = w_out * stride_w - pad_w;
      int h_in = h_out * stride_h - pad_h;
      const int data_offset_h_ptr =
          (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int data_offset_w_ptr =
          (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
           w_out);
      const T offset_h = data_offset_ptr[data_offset_h_ptr];
      const T offset_w = data_offset_ptr[data_offset_w_ptr];
      T inv_h = h_in + i * dilation_h + offset_h;
      T inv_w = w_in + j * dilation_w + offset_w;
      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
        inv_h = inv_w = -2;
      const T weight = get_coordinate_weight(inv_h, inv_w, height, width,
                                             data_im_ptr + cnt * height * width,
                                             width, bp_dir);
      val += weight * data_col_ptr[col_pos];
      cnt += 1;
    }

    grad_offset[index] = val;
  }
}

你可能感兴趣的:(神经网络图像识别pytorch)