Image Super-Resolution Using Very Deep Residual Channel Attention Networks

I. Contributions

There are three main contributions:

  • A very deep residual channel attention network (RCAN) is proposed for image super-resolution.
  • A residual in residual (RIR) structure is proposed to build very deep networks that can be trained effectively, using long and short skip connections.
  • Channel attention (CA) is used to adaptively rescale channel-wise features.

II. Method

1. Network Architecture

[Figure: overall RCAN network architecture]

The low-resolution image is first passed through a single convolution layer:

$F_0 = H_{SF}(I_{LR})$

$F_0$ is the shallow feature, which is then fed into the RIR module to extract deep features:

$F_{DF} = H_{RIR}(F_0)$

Next comes the upsampling module:

$F_{UP} = H_{UP}(F_{DF})$

and finally the reconstruction layer:

$I_{SR} = H_{REC}(F_{UP}) = H_{RCAN}(I_{LR})$

The loss is the L1 loss over the N training pairs:

$L(\Theta) = \frac{1}{N}\sum_{i=1}^{N}\left\|H_{RCAN}(I_{LR}^{i}) - I_{HR}^{i}\right\|_{1}$
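
To make the data flow above concrete, here is a minimal sketch (my own, not the official implementation) of how the four stages could be wired together in PyTorch. It assumes a ResidualGroup module as sketched in section 2 below and a plain pixel-shuffle upsampler; the name RCANSketch and the default hyper-parameters are illustrative only.

import torch
import torch.nn as nn

class RCANSketch(nn.Module):
    """Shallow conv -> RIR (residual groups + long skip) -> upsample -> reconstruction."""
    def __init__(self, n_colors=3, n_feats=64, n_groups=10, n_blocks=20,
                 scale=2, reduction=16, kernel_size=3):
        super().__init__()
        conv = lambda in_c, out_c: nn.Conv2d(in_c, out_c, kernel_size, padding=kernel_size // 2)

        # shallow feature extraction: F_0 = H_SF(I_LR)
        self.head = conv(n_colors, n_feats)

        # RIR: G residual groups followed by one conv; the long skip connection is added in forward()
        self.body = nn.Sequential(
            *[ResidualGroup(n_feats, kernel_size, reduction, n_blocks) for _ in range(n_groups)],
            conv(n_feats, n_feats),
        )

        # upsampling module: one pixel-shuffle step for the given scale
        self.upsample = nn.Sequential(
            conv(n_feats, n_feats * scale * scale),
            nn.PixelShuffle(scale),
        )

        # reconstruction layer: maps features back to an image
        self.tail = conv(n_feats, n_colors)

    def forward(self, x):
        f0 = self.head(x)           # F_0
        f_df = self.body(f0) + f0   # F_DF, with the long skip connection (LSC)
        f_up = self.upsample(f_df)  # F_UP
        return self.tail(f_up)      # I_SR

Training would then simply minimize nn.L1Loss() between the network output and the high-resolution ground truth, matching the L1 loss above.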

2. Residual in Residual (RIR)

As shown in the network architecture, RIR consists of G residual groups (RG) plus a long skip connection (LSC); each RG contains B residual channel attention blocks (RCAB) with a short skip connection (SSC).
RCAB: two convolution layers (with a ReLU in between) + CA + a residual connection, as the code below shows.

## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    def __init__(
        self, conv, n_feat, kernel_size, reduction,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(RCAB, self).__init__()
        modules_body = []
        for i in range(2):
            # two conv layers; the activation is applied only after the first one
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: modules_body.append(nn.BatchNorm2d(n_feat))
            if i == 0: modules_body.append(act)
        # channel attention rescales the channel-wise features
        modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        #res = self.body(x).mul(self.res_scale)  # optional residual scaling
        res += x  # residual (identity) connection
        return res
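
To show how the RCAB above is assembled into the RIR structure, here is a sketch (mirroring the structure described above, not copied from the official repository) of one Residual Group: B RCABs plus a trailing conv, wrapped by the short skip connection. G such groups, one more conv, and the long skip connection then form the RIR body used in the skeleton in section 1.

import torch.nn as nn

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    # plain convolution with "same" padding
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=kernel_size // 2, bias=bias)

class ResidualGroup(nn.Module):
    """B RCABs + one conv, wrapped by a short skip connection (SSC)."""
    def __init__(self, n_feat, kernel_size, reduction, n_resblocks, conv=default_conv):
        super(ResidualGroup, self).__init__()
        modules_body = [
            RCAB(conv, n_feat, kernel_size, reduction,
                 bias=True, bn=False, act=nn.ReLU(True), res_scale=1)
            for _ in range(n_resblocks)
        ]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x  # short skip connection (SSC)
        return res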

3. Channel Attention (CA)

Channel Attention (CA) Layer
Two points must be considered when designing the channel attention mechanism:

  • First, it must be able to learn nonlinear interactions between channels.
  • Second, since multiple channel-wise features can be emphasized (as opposed to a one-hot activation), it must learn a non-mutually-exclusive relationship; in other words, the gate should be able to switch on several channels at the same time, which is why a sigmoid is used rather than an exclusive softmax-style selection (see the small example after this list).
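
To make the second point concrete, here is a tiny, hypothetical example (my own, not from the paper): a sigmoid gate can keep several channels strongly activated at once, whereas a softmax forces the channel weights to compete and sum to 1.

import torch

scores = torch.tensor([2.0, 1.5, -3.0, 0.0])  # hypothetical per-channel scores
print(torch.sigmoid(scores))                  # approx [0.88, 0.82, 0.05, 0.50]: several channels emphasized at once
print(torch.softmax(scores, dim=0))           # approx [0.57, 0.35, 0.004, 0.08]: channels compete, weights sum to 1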

First, a global average pooling layer squeezes each channel into a single statistic:

$z_c = H_{GP}(x_c) = \frac{1}{H \times W}\sum_{i=1}^{H}\sum_{j=1}^{W} x_c(i, j)$
A gating mechanism is then applied:

$s = f\left(W_U\,\delta\left(W_D\, z\right)\right)$

where $f(\cdot)$ is the sigmoid function and $\delta(\cdot)$ is ReLU.

$W_D$ is a convolution layer that downscales the number of channels by a reduction ratio $r$, and $W_U$ is another convolution layer that upscales the channels back by the same ratio $r$.

Finally, the channel-wise rescaling is applied:

$\hat{x}_c = s_c \cdot x_c$
The code is as follows:

class CALayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1) 
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),  # channel downscaling (W_D)
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),  # channel upscaling (W_U)
                nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y
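
A quick shape check (my own usage example) of the layer above:

import torch

ca = CALayer(channel=64, reduction=16)
x = torch.randn(1, 64, 32, 32)  # a batch with 64 feature channels
out = ca(x)
print(out.shape)                # torch.Size([1, 64, 32, 32]): same shape, each channel rescaled by its learned weight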
