Scale-wise Convolution for Image Restoration论文理解

这是一篇AAAI2020的文章,这篇论文的核心思想是根据尺度不变性(尺度的变化不会改变主要特征),在残差块中增加不同尺度的特征的融合。可以概括为:提取特征金字塔,尺度变换+特征融合,跳跃连接,输出。

以论文中的图为例,具体操作可表述如下:

Scale-wise Convolution for Image Restoration论文理解_第1张图片

  • 输入图像,经过卷积生成一个特征金字塔,该特征金字塔经过残差块生成一个新的特征金字塔(也可以通过多个残差块生成多个特征金字塔),最后再进行卷积。
  • 下方为一个跳跃连接,即将输入的图像卷积一次。
  • 、最后相加,经过Pixel Shuffle输出。

结合代码数一下残差块的工作过程:

Scale-wise Convolution for Image Restoration论文理解_第2张图片Scale-wise Convolution for Image Restoration论文理解_第3张图片

 从上面两张图看以看出,特征金字塔在残差块中的尺度变换过程,就是高分辨率变低分辨率,低分辨率变高分辨率,到达同一分辨率后将特征进行融合,然后生成新的特征金字塔。

从代码上看(https://github.com/ychfan/scn/blob/master/models/scn.py),残差块中的操作集中在BLOCK类中, 我们不妨假设,网络输入的特征为F,在经过Head类(进行下采样操作)后,网络的特征列表为x_list = [F,1/2F,1/4F]。首先看到self.body,其功能是对输入进行conv2d+bn+conv2d的操作,可以看到这是一个残差连接的前半部分的操作(与输入连接之前),现在的特征是res_list = [F,1/2F,1/4F]。self.down与self.up是进行下采样和上采样操作,只对res_list中的特定元素操作,用与尺度变换。down_res_list = [F] + [1/2F] + [1/4F] = [F,1/2F,1/4F],up_res_list = [F,1/2F] + [1/4F] =  [F,1/2F,1/4F]。最后所有四个列表为:

x_list = [F,1/2F,1/4F],res_list = [F,1/2F,1/4F],down_res_list = [F,1/2F,1/4F],up_res_list = [F,1/2F,1/4F],四中特征的尺度变为一样的了。但是其中间的尺度进行过上下变换,为什么可以这样做?答案就是尺度不变性。最后将四个特征列表中的元素对应相加,获得新的特征金字塔,同时完成了残差连接的操作。上图左的意思就是尺度变换,进行融合,得到新的特征;上图右的交叉连接说的也是尺度大小变换。

class Block(nn.Module):

  def __init__(self,
               num_residual_units,
               kernel_size,
               width_multiplier=1,
               weight_norm=torch.nn.utils.weight_norm,
               res_scale=1):
    super(Block, self).__init__()
    body = []
    conv = weight_norm(  //conv2d+bn
        nn.Conv2d(
            num_residual_units,
            int(num_residual_units * width_multiplier),
            kernel_size,
            padding=kernel_size // 2))
    init.constant_(conv.weight_g, 2.0)
    init.zeros_(conv.bias)
    body.append(conv)
    body.append(nn.ReLU(True))//激活
    conv = weight_norm(//conv2d+bn
        nn.Conv2d(
            int(num_residual_units * width_multiplier),
            num_residual_units,
            kernel_size,
            padding=kernel_size // 2))
    init.constant_(conv.weight_g, res_scale)
    init.zeros_(conv.bias)
    body.append(conv)

    self.body = nn.Sequential(*body)

    down = []
    down.append(
        weight_norm(nn.Conv2d(num_residual_units, num_residual_units, 1)))
    down.append(nn.UpsamplingBilinear2d(scale_factor=0.5))
    self.down = nn.Sequential(*down)

    up = []
    up.append(weight_norm(nn.Conv2d(num_residual_units, num_residual_units, 1)))
    up.append(nn.UpsamplingBilinear2d(scale_factor=2.0))
    self.up = nn.Sequential(*up)

  def forward(self, x_list):
    res_list = [self.body(x) for x in x_list]
    down_res_list = [res_list[0]] + [self.down(x) for x in res_list[:-1]]
    up_res_list = [self.up(x) for x in res_list[1:]] + [res_list[-1]]
    x_list = [
        x + r + d + u
        for x, r, d, u in zip(x_list, res_list, down_res_list, up_res_list)
    ]
    return x_list

 

你可能感兴趣的:(深度学习)