caffe系列:deeplab中的插值网络层前传和反传的实现分析

一、前言

最近在torch中实现一个网络,需要用到插值层,索性就看了下deeplab中插值网络层的实现,移植到torch中,废话少说,先粗略地写个放上来。

二、双线性插值理论

暂时没空写那么详细,双线性插值的理论可以看

https://en.wikipedia.org/wiki/Bilinear_interpolation


三、插值层的前向和反向传播的实现分析

插值网络层的实现:

interp_layer.cpp文件

https://bitbucket.org/aquariusjay/deeplab-public-ver2/src/071ef5a59aad8d9e6e1f5b8dff3d7a5c984a3d3a/src/caffe/layers/interp_layer.cpp?at=master&fileviewer=file-view-default


前传和反传的CPU实现interp.cpp

https://bitbucket.org/aquariusjay/deeplab-public-ver2/src/071ef5a59aad8d9e6e1f5b8dff3d7a5c984a3d3a/src/caffe/util/interp.cpp?at=master&fileviewer=file-view-default

我这里只分析CPU的实现,GPU没有时间看

(0)网络层中参数及概览

template 
void InterpLayer::LayerSetUp(const vector*>& bottom,
      const vector*>& top) {
  InterpParameter interp_param = this->layer_param_.interp_param();
  pad_beg_ = interp_param.pad_beg();
  pad_end_ = interp_param.pad_end();
  // pad必须小于等于0,这里实际上是crop
  CHECK_LE(pad_beg_, 0) << "Only supports non-pos padding (cropping) for now";
  CHECK_LE(pad_end_, 0) << "Only supports non-pos padding (cropping) for now";
}

template 
void InterpLayer::Reshape(const vector*>& bottom,
      const vector*>& top) {
  num_ = bottom[0]->num();
  channels_ = bottom[0]->channels();
  height_in_ = bottom[0]->height();
  width_in_ = bottom[0]->width();
  // crop之后的高度
  height_in_eff_ = height_in_ + pad_beg_ + pad_end_;
  // crop之后的宽度
  width_in_eff_ = width_in_ + pad_beg_ + pad_end_;
  InterpParameter interp_param = this->layer_param_.interp_param();
  if (interp_param.has_shrink_factor() &&
      !interp_param.has_zoom_factor()) {
    // 缩小因子
    const int shrink_factor = interp_param.shrink_factor();
    // 检查是否大于等于1
    CHECK_GE(shrink_factor, 1) << "Shrink factor must be positive";
    // 计算缩小之后的大小,宽度和高度, 上取整
    height_out_ = (height_in_eff_ - 1) / shrink_factor + 1;
    width_out_ = (width_in_eff_ - 1) / shrink_factor + 1;
  } else if (interp_param.has_zoom_factor() &&
             !interp_param.has_shrink_factor()) {
    // 放大因子
    const int zoom_factor = interp_param.zoom_factor();
    // 检查是否大于等于1
    CHECK_GE(zoom_factor, 1) << "Zoom factor must be positive";
    // 这是什么鬼,为啥放大因子要减去1
    height_out_ = height_in_eff_ + (height_in_eff_ - 1) * (zoom_factor - 1);
    width_out_ = width_in_eff_ + (width_in_eff_ - 1) * (zoom_factor - 1);
  } else if (interp_param.has_height() && interp_param.has_width()) {
    // 如果直接给出输出大小
    height_out_  = interp_param.height();
    width_out_  = interp_param.width();
  } else if (interp_param.has_shrink_factor() &&
             interp_param.has_zoom_factor()) {
    // 如果给出缩放因子和放大因子
    const int shrink_factor = interp_param.shrink_factor();
    const int zoom_factor = interp_param.zoom_factor();
    CHECK_GE(shrink_factor, 1) << "Shrink factor must be positive";
    CHECK_GE(zoom_factor, 1) << "Zoom factor must be positive";
    // 先缩放,再放大之后的大小
    height_out_ = (height_in_eff_ - 1) / shrink_factor + 1;
    width_out_ = (width_in_eff_ - 1) / shrink_factor + 1;
    height_out_ = height_out_ + (height_out_ - 1) * (zoom_factor - 1);
    width_out_ = width_out_ + (width_out_ - 1) * (zoom_factor - 1);
  } else {
    LOG(FATAL);
  }
  CHECK_GT(height_in_eff_, 0) << "height should be positive";
  CHECK_GT(width_in_eff_, 0) << "width should be positive";
  CHECK_GT(height_out_, 0) << "height should be positive";
  CHECK_GT(width_out_, 0) << "width should be positive";
  // reshape输出的blob
  top[0]->Reshape(num_, channels_, height_out_, width_out_);
}

template 
void InterpLayer::Forward_cpu(const vector*>& bottom,
      const vector*>& top) {
  // 调用具体的实现
  caffe_cpu_interp2(num_ * channels_,
    bottom[0]->cpu_data(), - pad_beg_, - pad_beg_, height_in_eff_, width_in_eff_, height_in_, width_in_,
    top[0]->mutable_cpu_data(), 0, 0, height_out_, width_out_, height_out_, width_out_);
}

template 
void InterpLayer::Backward_cpu(const vector*>& top,
      const vector& propagate_down, const vector*>& bottom) {
  if (!propagate_down[0]) { return; }
  // 调用具体的实现
  caffe_set(bottom[0]->count(), Dtype(0), bottom[0]->mutable_cpu_diff());
  caffe_cpu_interp2_backward(num_ * channels_,
    bottom[0]->mutable_cpu_diff(), - pad_beg_, - pad_beg_, height_in_eff_, width_in_eff_, height_in_, width_in_,
    top[0]->cpu_diff(), 0, 0, height_out_, width_out_, height_out_, width_out_);
}




(1)前向

注释有部分是英文的,先贴上来,后面有空再解释

// Bi-linear interpolation
// https://en.wikipedia.org/wiki/Bilinear_interpolation
// IN : [channels height1 width1] cropped from a bigger [Height1 Width1] image
// OUT: [channels height2 width2] cropped from a bigger [Height2 Width2] image
template 
void caffe_cpu_interp2(const int channels,
    const Dtype *data1, const int x1, const int y1, const int height1, const int width1, const int Height1, const int Width1,
    Dtype *data2, const int x2, const int y2, const int height2, const int width2, const int Height2, const int Width2) {
    // 检查是否合法
    CHECK(x1 >= 0 && y1 >= 0 && height1 > 0 && width1 > 0 && x2 >= 0 && y2 >= 0 && height2 > 0 && width2 > 0);
    CHECK(Width1 >= width1 + x1 && Height1 >= height1 + y1 && Width2 >= width2 + x2 && Height2 >= height2 + y2);
    // 参数解释
    // channels 输入数据的个数
    // data1 输入数据指针
    // x1 输入数据w偏移量
    // y1 输入数据h偏移量
    // height1 输入数据crop之后的高度
    // width1  输入数据crop之后的宽度
    // Height1 输入数据的原始高度
    // Width1 输入数据的原始宽度
    // data2 输出的数据指针
    // x2 输出数据w偏移
    // y2 输出数据h偏移
    // height2 输出数据crop之后的高度
    // width2 输出数据crop之后的宽度
    // Height2 输出数据的原始高度
    // Width2 输出数据的原始宽度

    // special case: just copy
    if (height1 == height2 && width1 == width2) 
    {// 输入和输出一样大小的
        for (int h2 = 0; h2 < height2; ++h2) 
        {
            const int h1 = h2;
            for (int w2 = 0; w2 < width2; ++w2) 
            {
                const int w1 = w2;
                if (packed) 
                {
                    // what is packed?
                    const Dtype* pos1 = &data1[channels * ((y1 + h1) * Width1 + (x1 + w1))];
                    Dtype* pos2 = &data2[channels * ((y2 + h2) * Width2 + (x2 + w2))];
                    for (int c = 0; c < channels; ++c) 
                    {
                        pos2[0] = pos1[0];
                        pos1++;
                        pos2++;
                    }
                }
                else 
                {
                    // normal situation
                    const Dtype* pos1 = &data1[(y1 + h1) * Width1 + (x1 + w1)];
                    Dtype* pos2 = &data2[(y2 + h2) * Width2 + (x2 + w2)];
                    for (int c = 0; c < channels; ++c) 
                    {
                        pos2[0] = pos1[0];
                        pos1 += Width1 * Height1;
                        pos2 += Width2 * Height2;
                    }
                }
            }
        }
        return;
    }

    // calculate height factor and width factor
    // input / output
    const float rheight = (height2 > 1) ? static_cast(height1 - 1) / (height2 - 1) : 0.f;
    const float rwidth = (width2 > 1) ? static_cast(width1 - 1) / (width2 - 1) : 0.f;
    for (int h2 = 0; h2 < height2; ++h2) {// calculate h1 and w1 according to h2 and w2
        // calculate height in the input image according to h factor
        const float h1r = rheight * h2;
        // convert h1r to int
        const int h1 = h1r;
        // h1p indicates whether the pos is valid
        const int h1p = (h1 < height1 - 1) ? 1 : 0;
        // h0lambda and h1lambda indicate two residuals in wiki equation
        const Dtype h1lambda = h1r - h1;
        const Dtype h0lambda = Dtype(1.) - h1lambda;
        for (int w2 = 0; w2 < width2; ++w2) 
        {
            // calculate width in the input image according to w factor
            const float w1r = rwidth * w2;
            // convert w1r to int
            const int w1 = w1r;
            // w1p indicates whether the pos is valid
            const int w1p = (w1 < width1 - 1) ? 1 : 0;
            // w0lambda and w1lambda indicate two residuals in wiki equation
            const Dtype w1lambda = w1r - w1;
            const Dtype w0lambda = Dtype(1.) - w1lambda;
            if (packed) 
            {
                const Dtype* pos1 = &data1[channels * ((y1 + h1) * Width1 + (x1 + w1))];
                Dtype* pos2 = &data2[channels * ((y2 + h2) * Width2 + (x2 + w2))];
                for (int c = 0; c < channels; ++c) 
                {
                    pos2[0] =
                    h0lambda * (w0lambda * pos1[0]            + w1lambda * pos1[channels * w1p]) + 
                    h1lambda * (w0lambda * pos1[channels * h1p * Width1] + w1lambda * pos1[channels * (h1p * Width1 + w1p)]);
                    pos1++;
                    pos2++;
                }
            }
            else 
            {
                const Dtype* pos1 = &data1[(y1 + h1) * Width1 + (x1 + w1)];
                Dtype* pos2 = &data2[(y2 + h2) * Width2 + (x2 + w2)];
                for (int c = 0; c < channels; ++c) {// visit all channels
                    // calculate bi-linear interpolation according to wiki
                    pos2[0] =
                    h0lambda * (w0lambda * pos1[0]            + w1lambda * pos1[w1p]) + 
                    h1lambda * (w0lambda * pos1[h1p * Width1] + w1lambda * pos1[h1p * Width1 + w1p]);
                    pos1 += Width1 * Height1;
                    pos2 += Width2 * Height2;
                }
            }
        }
    }
}




(2)反向

// Backward (adjoint) operation 1 <- 2 (accumulates)
template 
void caffe_cpu_interp2_backward(const int channels,
    Dtype *data1, const int x1, const int y1, const int height1, const int width1, const int Height1, const int Width1,
    const Dtype *data2, const int x2, const int y2, const int height2, const int width2, const int Height2, const int Width2) {
    // check parameters
    CHECK(x1 >= 0 && y1 >= 0 && height1 > 0 && width1 > 0 && x2 >= 0 && y2 >= 0 && height2 > 0 && width2 > 0);
    CHECK(Width1 >= width1 + x1 && Height1 >= height1 + y1 && Width2 >= width2 + x2 && Height2 >= height2 + y2);
    
    // special case: same-size matching grids
    if (height1 == height2 && width1 == width2) 
    {
        for (int h2 = 0; h2 < height2; ++h2) 
        {
            const int h1 = h2;
            for (int w2 = 0; w2 < width2; ++w2) 
            {
                const int w1 = w2;
                if (packed) 
                {
                    Dtype* pos1 = &data1[channels * ((y1 + h1) * Width1 + (x1 + w1))];
                    const Dtype* pos2 = &data2[channels * ((y2 + h2) * Width2 + (x2 + w2))];
                    for (int c = 0; c < channels; ++c) 
                    {
                        pos1[0] += pos2[0];
                        pos1++;
                        pos2++;
                    }
                }
                else 
                {
                    Dtype* pos1 = &data1[(y1 + h1) * Width1 + (x1 + w1)];
                    const Dtype* pos2 = &data2[(y2 + h2) * Width2 + (x2 + w2)];
                    for (int c = 0; c < channels; ++c) 
                    {
                        // acumulate gradients
                        pos1[0] += pos2[0];
                        pos1 += Width1 * Height1;
                        pos2 += Width2 * Height2;
                    }
                }
            }
        }
        return;
    }

    // calculate w/h factor
    // input image / output image
    const float rheight = (height2 > 1) ? static_cast(height1 - 1) / (height2 - 1) : 0.f;
    const float rwidth = (width2 > 1) ? static_cast(width1 - 1) / (width2 - 1) : 0.f;
    for (int h2 = 0; h2 < height2; ++h2) 
    {
        const float h1r = rheight * h2;
        const int h1 = h1r;
        const int h1p = (h1 < height1 - 1) ? 1 : 0;
        const Dtype h1lambda = h1r - h1;
        const Dtype h0lambda = Dtype(1.) - h1lambda;
        for (int w2 = 0; w2 < width2; ++w2) 
        {
            const float w1r = rwidth * w2;
            const int w1 = w1r;
            const int w1p = (w1 < width1 - 1) ? 1 : 0;
            const Dtype w1lambda = w1r - w1;
            const Dtype w0lambda = Dtype(1.) - w1lambda;
            if (packed) 
            {
                Dtype* pos1 = &data1[channels * ((y1 + h1) * Width1 + (x1 + w1))];
                const Dtype* pos2 = &data2[channels * ((y2 + h2) * Width2 + (x2 + w2))];
                for (int c = 0; c < channels; ++c) 
                {
                    pos1[0] += h0lambda * w0lambda * pos2[0];
                    pos1[channels * w1p] += h0lambda * w1lambda * pos2[0];
                    pos1[channels * h1p * Width1] += h1lambda * w0lambda * pos2[0];
                    pos1[channels * (h1p * Width1 + w1p)] += h1lambda * w1lambda * pos2[0];
                    pos1++;
                    pos2++;
                }
            }
            else 
            {
                Dtype* pos1 = &data1[(y1 + h1) * Width1 + (x1 + w1)];
                const Dtype* pos2 = &data2[(y2 + h2) * Width2 + (x2 + w2)];
                for (int c = 0; c < channels; ++c) 
                {
                    // acumulate gradients from scaled pixels
                    pos1[0] += h0lambda * w0lambda * pos2[0];
                    pos1[w1p] += h0lambda * w1lambda * pos2[0];
                    pos1[h1p * Width1] += h1lambda * w0lambda * pos2[0];
                    pos1[h1p * Width1 + w1p] += h1lambda * w1lambda * pos2[0];
                    pos1 += Width1 * Height1;
                    pos2 += Width2 * Height2;
                }
            }
        }
    }
}

四、总结:

实际上这里实现的双线性插值还挺简单的,就是根据维基百科上面的实现来实现前传的,另外,前传的时候需要注意边界问题

此外,反传则是将插值之后每个输出像素的梯度,按照前传时使用的四个双线性权重分配回插值之前对应的四个输入位置,并进行累加来实现的。


有不懂的可以留言

后来发现torch里面已经有人移植过去了,干脆贴进来


五、torch中的实现

https://github.com/torch/nn/blob/master/SpatialUpSamplingBilinear.lua

lua的接口

require 'nn.THNN'
-- Declare the class 'nn.SpatialUpSamplingBilinear' with 'nn.Module' as its
-- parent; `parent` is used below to invoke the parent constructor.
local SpatialUpSamplingBilinear, parent =
   torch.class('nn.SpatialUpSamplingBilinear', 'nn.Module')

--[[
Applies a 2D bilinear up-sampling over an input image composed of several
input planes.

The Y and X dimensions are assumed to be the last 2 tensor dimensions.  For
instance, if the tensor is 4D, then dim 3 is the y dimension and dim 4 is the x.

scale_factor is assumed to be a positive integer.
owidth  = width * scale_factor
oheight = height * scale_factor

Alternatively, owidth and oheight can be directly provided as input.
--]]

-- Constructor.
--
-- params is either
--   * a table with fields `owidth` and `oheight` giving the output size, or
--   * a number, the integer scale factor (must be >= 1).
-- Raises an error for any other argument, for a scale factor below 1, or for
-- a non-integer scale factor.
function SpatialUpSamplingBilinear:__init(params)
   parent.__init(self)

   self.owidth, self.oheight, self.scale_factor = nil, nil, nil
   if torch.type(params) == 'table' then
      self.owidth, self.oheight = params.owidth, params.oheight
   else
      -- Fail with a readable message instead of an opaque comparison error
      -- when the argument is missing or of the wrong type.
      if type(params) ~= 'number' then
         error('SpatialUpSamplingBilinear expects a scale_factor number ' ..
               'or a table {owidth=..., oheight=...}')
      end
      self.scale_factor = params
      -- A factor of exactly 1 is accepted, so the message says ">= 1".
      if self.scale_factor < 1 then
         error('scale_factor must be greater than or equal to 1')
      end
      if math.floor(self.scale_factor) ~= self.scale_factor then
         error('scale_factor must be integer')
      end
   end
   -- Size records filled in by setSize().
   self.inputSize = torch.LongStorage(4)
   self.outputSize = torch.LongStorage(4)
end

-- Returns `input` and `gradOutput` (if given) in contiguous form.
-- A tensor that is not contiguous is copied into a buffer cached on `self`
-- (self._input / self._gradOutput, created on first use) and that buffer is
-- returned in its place; contiguous tensors are returned unchanged.
local function makeContiguous(self, input, gradOutput)
   if not input:isContiguous() then
      local inputBuffer = self._input or input.new()
      self._input = inputBuffer
      inputBuffer:resizeAs(input):copy(input)
      input = inputBuffer
   end
   if gradOutput and not gradOutput:isContiguous() then
      local gradBuffer = self._gradOutput or gradOutput.new()
      self._gradOutput = gradBuffer
      gradBuffer:resizeAs(gradOutput):copy(gradOutput)
      gradOutput = gradBuffer
   end
   return input, gradOutput
end

-- Records the size of `input` in self.inputSize and derives self.outputSize.
-- The last two dimensions are treated as height and width: they are
-- multiplied by scale_factor when one was given, otherwise replaced by
-- oheight / owidth. All other dimensions are copied unchanged.
function SpatialUpSamplingBilinear:setSize(input)
   local nDim = input:dim()
   local widthDim = nDim
   local heightDim = nDim - 1
   for d = 1, nDim do
      local extent = input:size(d)
      self.inputSize[d] = extent
      self.outputSize[d] = extent
   end
   if self.scale_factor == nil then
      self.outputSize[heightDim] = self.oheight
      self.outputSize[widthDim] = self.owidth
   else
      self.outputSize[heightDim] = self.outputSize[heightDim] * self.scale_factor
      self.outputSize[widthDim] = self.outputSize[widthDim] * self.scale_factor
   end
end

-- Forward pass. Accepts a 3D or 4D tensor; a 3D tensor is viewed as 4D with
-- a leading dimension for the THNN call and the result is squeezed back.
-- Returns self.output.
function SpatialUpSamplingBilinear:updateOutput(input)
   local nInputDim = input:dim()
   assert(nInputDim == 4 or nInputDim == 3,
            'SpatialUpSamplingBilinear only supports 3D or 4D tensors' )
   input = makeContiguous(self, input)
   local inputwas3D = (input:dim() == 3)
   if inputwas3D then
      -- Prepend a leading dimension so the backend always sees 4D data.
      input = input:view(-1, input:size(1), input:size(2), input:size(3))
   end
   local widthDim = input:dim()
   local heightDim = widthDim - 1
   self:setSize(input)
   input.THNN.SpatialUpSamplingBilinear_updateOutput(
      input:cdata(),
      self.output:cdata(),
      self.outputSize[heightDim],
      self.outputSize[widthDim]
   )
   if inputwas3D then
      -- Drop the leading dimension again.
      input = input:squeeze(1)
      self.output = self.output:squeeze(1)
   end
   return self.output
end

-- Backward pass. `input` and `gradOutput` must have the same number of
-- dimensions (3 or 4); 3D tensors are viewed as 4D for the THNN call and the
-- result is squeezed back. Uses self.outputSize as last set by setSize().
-- Returns self.gradInput.
function SpatialUpSamplingBilinear:updateGradInput(input, gradOutput)
   local nInputDim = input:dim()
   assert(nInputDim == 4 or nInputDim == 3,
            'SpatialUpSamplingBilinear only support 3D or 4D tensors' )
   assert(nInputDim == gradOutput:dim(),
	  'Input and gradOutput should be of same dimension' )
   input, gradOutput = makeContiguous(self, input, gradOutput)
   local inputwas3D = (input:dim() == 3)
   if inputwas3D then
      -- Prepend a leading dimension to both tensors.
      input = input:view(-1, input:size(1), input:size(2), input:size(3))
      gradOutput = gradOutput:view(-1, gradOutput:size(1), gradOutput:size(2),
				   gradOutput:size(3))
   end
   local widthDim = input:dim()
   local heightDim = widthDim - 1
   self.gradInput:resizeAs(input)
   input.THNN.SpatialUpSamplingBilinear_updateGradInput(
      gradOutput:cdata(),
      self.gradInput:cdata(),
      input:size(1),
      input:size(2),
      input:size(3),
      input:size(4),
      self.outputSize[heightDim],
      self.outputSize[widthDim]
   )
   if inputwas3D then
      -- Drop the leading dimension again.
      input = input:squeeze(1)
      gradOutput = gradOutput:squeeze(1)
      self.gradInput = self.gradInput:squeeze(1)
   end
   return self.gradInput
end


-- Human-readable description: "<type>(<scale_factor>)" when a scale factor
-- is set, otherwise "<type>(<oheight>, <owidth>)".
function SpatialUpSamplingBilinear:__tostring__()
   local typeName = torch.type(self)
   if self.scale_factor == nil then
      return string.format('%s(%d, %d)', typeName, self.oheight, self.owidth)
   end
   return string.format('%s(%d)', typeName, self.scale_factor)
end


底层的c实现

https://raw.githubusercontent.com/torch/nn/master/lib/THNN/generic/SpatialUpSamplingBilinear.c

// Adapted from interp.cpp from Caffe util by Pauline Luc
// Originally developed by George Papandreou

#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/SpatialUpSamplingBilinear.c"
#else

// Validates the arguments shared by the forward and backward passes.
//  - All four spatial sizes must be strictly positive.
//  - When `input` is non-NULL it must be a 4D tensor.
//  - When `gradOutput` is non-NULL it is checked as a 4D tensor against
//    (nBatch, nChannels, outputHeight, outputWidth).
// Either tensor pointer may be NULL to skip its check.
static inline void THNN_(SpatialUpSamplingBilinear_shapeCheck)
     (THTensor *input, THTensor *gradOutput,
      int nBatch, int nChannels,
      int inputHeight, int inputWidth,
      int outputHeight, int outputWidth) {
  THArgCheck(inputHeight > 0 && inputWidth > 0
	     && outputHeight > 0 && outputWidth > 0, 2,
	     "input and output sizes should be greater than 0,"
	     " but got input (H: %d, W: %d) output (H: %d, W: %d)",
	     inputHeight, inputWidth, outputHeight, outputWidth);
  if (input != NULL) {
    THNN_ARGCHECK(input->nDimension == 4, 2, input,
		  "4D input tensor expected but got: %s");
  }

  if (gradOutput != NULL) {
    // One check per dimension: batch, channels, height, width.
    THNN_CHECK_DIM_SIZE(gradOutput, 4, 0, nBatch);
    THNN_CHECK_DIM_SIZE(gradOutput, 4, 1, nChannels);
    THNN_CHECK_DIM_SIZE(gradOutput, 4, 2, outputHeight);
    THNN_CHECK_DIM_SIZE(gradOutput, 4, 3, outputWidth);
  }
}

// Forward pass of bilinear up-sampling.
//
//   state         THNN state handle (not used in this function)
//   input         4D tensor (batch, channels, inputHeight, inputWidth)
//   output        resized here to (batch, channels, outputHeight, outputWidth)
//   outputHeight  requested output height, must be > 0
//   outputWidth   requested output width, must be > 0
//
// Output pixel (h2, w2) is sampled at input position
// (h2 * (inputHeight - 1) / (outputHeight - 1),
//  w2 * (inputWidth - 1) / (outputWidth - 1)),
// so the corner pixels of input and output coincide. Batch and channel are
// folded into one plane count, since every plane is processed independently.
void THNN_(SpatialUpSamplingBilinear_updateOutput)(
    THNNState *state,
    THTensor *input,
    THTensor *output,
    int outputHeight,
    int outputWidth){

  int nbatch = THTensor_(size)(input, 0);
  int channels = THTensor_(size)(input, 1);
  int inputHeight = THTensor_(size)(input, 2);
  int inputWidth = THTensor_(size)(input, 3);

  THNN_(SpatialUpSamplingBilinear_shapeCheck)
    (input, NULL,
     nbatch, channels,
     inputHeight, inputWidth,
     outputHeight, outputWidth);

  // Contiguous handle on the input; it has to be released with
  // THTensor_(free) on every path that leaves this function.
  input = THTensor_(newContiguous)(input);
  THTensor_(resize4d)(output, 
		      THTensor_(size)(input, 0), 
		      THTensor_(size)(input, 1), 
		      outputHeight, outputWidth);
  THTensor_(zero)(output);
  real *idata = THTensor_(data)(input);
  real *odata = THTensor_(data)(output);
  // Treat every (batch, channel) pair as one independent plane.
  channels = nbatch * channels;
  THAssert(inputHeight > 0 && inputWidth > 0 && outputHeight > 0 && outputWidth > 0);
  // special case: just copy
  if (inputHeight == outputHeight && inputWidth == outputWidth) {
    for (int h2 = 0; h2 < outputHeight; ++h2) {
      const int h1 = h2;
      for (int w2 = 0; w2 < outputWidth; ++w2) {
        const int w1 = w2;
        const real* pos1 = &idata[h1 * inputWidth + w1];
        real* pos2 = &odata[h2 * outputWidth + w2];
        for (int c = 0; c < channels; ++c) {
          pos2[0] = pos1[0];
          pos1 += inputWidth * inputHeight;
          pos2 += outputWidth * outputHeight;
        }
      }
    }
    // Release the contiguous handle before the early exit; without this the
    // same-size path would leak it.
    THTensor_(free)(input);
    return;
  }
  // Scale factors (input / output); 0 when the output has one row/column.
  const float rheight =(outputHeight > 1) ? (float)(inputHeight - 1)/(outputHeight - 1) : 0.f;
  const float rwidth = (outputWidth > 1) ? (float)(inputWidth - 1) / (outputWidth - 1) : 0.f;
  for (int h2 = 0; h2 < outputHeight; ++h2) {
    const float h1r = rheight * h2;
    const int h1 = h1r;
    // Step to the lower neighbour row; 0 on the last row so that no row
    // outside the image is read.
    const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
    const real h1lambda = h1r - h1;
    const real h0lambda = (real)1. - h1lambda;
    for (int w2 = 0; w2 < outputWidth; ++w2) {
      const float w1r = rwidth * w2;
      const int w1 = w1r;
      const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
      const real w1lambda = w1r - w1;
      const real w0lambda = (real)1. - w1lambda;
      const real* pos1 = &idata[h1 * inputWidth + w1];
      real* pos2 = &odata[h2 * outputWidth + w2];
      for (int c = 0; c < channels; ++c) {
        // Weighted sum of the four neighbouring input pixels.
        pos2[0] = h0lambda * (w0lambda * pos1[0]+ w1lambda * pos1[w1p])
                  + h1lambda * (w0lambda * pos1[h1p * inputWidth]
                  + w1lambda * pos1[h1p * inputWidth + w1p]);
        pos1 += inputWidth * inputHeight;
        pos2 += outputWidth * outputHeight;
      }
    }
  }
  THTensor_(free)(input);
}

// Backward pass of bilinear up-sampling.
//
//   state         THNN state handle (not used in this function)
//   gradOutput    4D gradient w.r.t. the output,
//                 (nbatch, channels, outputHeight, outputWidth)
//   gradInput     resized here to (nbatch, channels, inputHeight, inputWidth),
//                 zeroed, then filled by accumulation
//   nbatch, channels, inputHeight, inputWidth   size of the forward input
//   outputHeight, outputWidth                   size of the forward output
//
// Each output gradient is split over the same four input neighbours, with
// the same weights, that the forward pass used for that output pixel.
void THNN_(SpatialUpSamplingBilinear_updateGradInput)(
    THNNState *state,
    THTensor *gradOutput,
    THTensor *gradInput,
    int nbatch,
    int channels,
    int inputHeight,
    int inputWidth,
    int outputHeight,
    int outputWidth){

  THNN_(SpatialUpSamplingBilinear_shapeCheck)
    (NULL, gradOutput,
     nbatch, channels,
     inputHeight, inputWidth,
     outputHeight, outputWidth);

  THTensor_(resize4d)(gradInput, nbatch, channels, inputHeight, inputWidth);
  THTensor_(zero)(gradInput);
  // Contiguous handle on gradOutput; it has to be released with
  // THTensor_(free) on every path that leaves this function.
  gradOutput = THTensor_(newContiguous)(gradOutput);
  real *data1 = THTensor_(data)(gradInput);
  real *data2 = THTensor_(data)(gradOutput);
  // Treat every (batch, channel) pair as one independent plane.
  channels = nbatch * channels;

  // special case: same-size matching grids
  if (inputHeight == outputHeight && inputWidth == outputWidth) {
    for (int h2 = 0; h2 < outputHeight; ++h2) {
      const int h1 = h2;
      for (int w2 = 0; w2 < outputWidth; ++w2) {
        const int w1 = w2;
        real* pos1 = &data1[h1 * inputWidth + w1];
        const real* pos2 = &data2[h2 * outputWidth + w2];
        for (int c = 0; c < channels; ++c) {
          pos1[0] += pos2[0];
          pos1 += inputWidth * inputHeight;
          pos2 += outputWidth * outputHeight;
        }
      }
    }
    // Release the contiguous handle before the early exit; without this the
    // same-size path would leak it.
    THTensor_(free)(gradOutput);
    return;
  }
  // Scale factors (input / output), identical to the forward pass.
  const float rheight =(outputHeight > 1) ? (float)(inputHeight - 1)/(outputHeight - 1) : 0.f;
  const float rwidth = (outputWidth > 1) ? (float)(inputWidth - 1)/(outputWidth - 1) : 0.f;
  for (int h2 = 0; h2 < outputHeight; ++h2) {
    const float h1r = rheight * h2;
    const int h1 = h1r;
    // Step to the lower neighbour row; 0 on the last row so that no row
    // outside the image is written.
    const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
    const real h1lambda = h1r - h1;
    const real h0lambda = (real)1. - h1lambda;
    for (int w2 = 0; w2 < outputWidth; ++w2) {
      const float w1r = rwidth * w2;
      const int w1 = w1r;
      const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
      const real w1lambda = w1r - w1;
      const real w0lambda = (real)1. - w1lambda;
      real* pos1 = &data1[h1 * inputWidth + w1];
      const real* pos2 = &data2[h2 * outputWidth + w2];
      for (int c = 0; c < channels; ++c) {
        // Accumulate the weighted gradient onto the four neighbours.
        pos1[0] += h0lambda * w0lambda * pos2[0];
        pos1[w1p] += h0lambda * w1lambda * pos2[0];
        pos1[h1p * inputWidth] += h1lambda * w0lambda * pos2[0];
        pos1[h1p * inputWidth + w1p] += h1lambda * w1lambda * pos2[0];
        pos1 += inputWidth * inputHeight;
        pos2 += outputWidth * outputHeight;
      }
    }
  }
  THTensor_(free)(gradOutput);
}

#endif







你可能感兴趣的:(caffe)