SAN:Second-Order Attention Network for Single Image Super-Resolution

Paper: Second-Order Attention Network for Single Image Super-Resolution

Code: daitao/SAN: Second-order Attention Network for Single Image Super-resolution (CVPR-2019)

Introduction

The paper proposes a channel attention mechanism based on second-order statistics, which yields more discriminative representations. It also refines the non-local mechanism: for low-level tasks, applying non-local attention over the whole image is computationally prohibitive, so the feature map is split into patches and the non-local operation is applied at region level.

Problems with Existing Methods

  1. Most existing CNN-based methods focus on designing wider or deeper networks while neglecting the feature correlations of intermediate layers, which limits the representational power of CNNs.

  2. The channel attention in SENet exploits only first-order statistics (e.g., global average pooling) and ignores statistics higher than first order, which limits the discriminative ability of the network.

Main Contributions

  1. A novel trainable second-order channel attention (SOCA) module, which adaptively rescales channel-wise features by using second-order feature statistics for more discriminative representations
    • Second-order statistics are more discriminative than first-order ones (see refs [1], [2])
    • Covariance normalization (see refs [1], [3]) plays a crucial role in obtaining more discriminative representations
  2. An RL-NL module: non-local operations applied at region level, which not only capture long-range spatial contextual information but also enlarge the receptive field
    • For low-level tasks, the linked paper shows experimentally that non-local operations over a suitably sized neighborhood outperform global non-local ones.
    • When the feature map is large, vanilla non-local also carries a heavy computational burden.

Method Overview

(Figure: overall network architecture of SAN)

Second-order Channel Attention (SOCA)

Given an input feature map, reshape it to $X \in \mathbb{R}^{C \times s}$ with $s = WH$, and compute the sample covariance matrix

$$\Sigma = X \bar{I} X^{T}, \qquad \bar{I} = \frac{1}{s}\left(I - \frac{1}{s}\mathbf{1}\right)$$

where $I$ and $\mathbf{1}$ are the $s \times s$ identity matrix and the matrix of all ones, respectively.
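
Note that multiplying by $\bar{I}$ simply centers the features: $X\bar{I}X^{T}$ equals the covariance of the mean-subtracted rows of $X$. A quick check (my example, not from the repo):

import torch

C, s = 4, 9
X = torch.randn(C, s)
I_bar = (torch.eye(s) - torch.full((s, s), 1.0 / s)) / s   # (1/s)(I - (1/s)1)
Xc = X - X.mean(dim=1, keepdim=True)                        # subtract each channel's mean
print(torch.allclose(X @ I_bar @ X.T, Xc @ Xc.T / s, atol=1e-6))   # True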

Since covariance normalization improves the discriminative power of the representation, $\Sigma$ is normalized. Because $\Sigma$ is a symmetric positive semi-definite matrix, it admits the eigenvalue decomposition (EIG)

$$\Sigma = U \Lambda U^{T}$$

where $U$ is an orthogonal matrix and $\Lambda = \mathrm{diag}(\lambda_1, \cdots, \lambda_C)$ is the diagonal matrix of eigenvalues in non-increasing order. Covariance normalization then becomes the matrix power

$$\hat{Y} = \Sigma^{\alpha} = U \Lambda^{\alpha} U^{T}$$

With $\alpha < 1$, this nonlinearly shrinks eigenvalues larger than 1 and amplifies those smaller than 1. Reference [1] of the contributions reports that $\alpha = 0.5$ yields the most discriminative representation.
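
Given the eigendecomposition, the normalization itself is only a few lines. The following is a minimal sketch of my own in plain PyTorch (the function name and the clamp are mine), not the paper's code:

def eig_normalize(sigma, alpha=0.5):
    """alpha-normalization of a symmetric PSD covariance matrix of shape (C, C)."""
    lam, U = torch.linalg.eigh(sigma)     # EIG is valid since Sigma is symmetric PSD
    lam = lam.clamp_min(0)                # clear tiny negative round-off eigenvalues
    return U @ torch.diag(lam.pow(alpha)) @ U.T

EIG (and SVD) are poorly supported on GPUs and hard to batch, which is exactly why the paper computes the square root with the Newton-Schulz iteration described further below.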

Computing the covariance matrix
import torch
from torch.autograd import Function


class Covpool(Function):
    """
        Global covariance pooling layer
    """
    @staticmethod
    def forward(ctx, input):
        x = input
        batchSize = x.data.shape[0]

        # channel and spatial dimensions
        dim = x.data.shape[1]
        h = x.data.shape[2]
        w = x.data.shape[3]

        # s = H * W
        M = h * w
        # Sigma = X I_hat X^T; I_hat is an s x s matrix, so x is reshaped to (dim, M)
        x = x.reshape(batchSize, dim, M)
        # I_hat = 1/s (I - 1/s 1) = (1/s) * I + (-1/s/s) * 1,
        # where I and 1 are the s x s identity matrix and the matrix of all ones
        I_hat = (1. / M) * torch.eye(M, M, device=x.device) + (-1. / M / M) * torch.ones(M, M, device=x.device)
        # broadcast I_hat to x's shape: repeat along the batch dimension
        I_hat = I_hat.view(1, M, M).repeat(batchSize, 1, 1).type(x.dtype)
        # compute the covariance matrix Sigma = X I_hat X^T;
        # x has shape (b, c, M), so the transpose swaps dims 1 and 2,
        # and x.bmm(I_hat) is the batched matrix product of x and I_hat
        y = x.bmm(I_hat).bmm(x.transpose(1, 2))
        # saved for the backward pass
        ctx.save_for_backward(input, I_hat)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        input, I_hat = ctx.saved_tensors
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        h = x.data.shape[2]
        w = x.data.shape[3]
        M = h * w
        x = x.reshape(batchSize, dim, M)
        # gradient of X I_hat X^T w.r.t. X: symmetrize the incoming gradient,
        # then right-multiply by X I_hat (I_hat is symmetric)
        grad_input = grad_output + grad_output.transpose(1, 2)
        grad_input = grad_input.bmm(x).bmm(I_hat)
        grad_input = grad_input.reshape(batchSize, dim, h, w)
        return grad_input
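
Custom autograd Functions are invoked through .apply. A quick sanity check on random input (my example, not from the repo):

x = torch.randn(2, 64, 12, 12, requires_grad=True)   # (batch, C, H, W)
sigma = Covpool.apply(x)
print(sigma.shape)                                    # torch.Size([2, 64, 64])
# each covariance matrix is symmetric up to floating-point round-off
print(torch.allclose(sigma, sigma.transpose(1, 2), atol=1e-5))   # True
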
Fast matrix normalization based on Newton-Schulz iteration

Inspired by Towards Faster Training of Global Covariance Pooling Networks by Iterative Matrix Square Root Normalization, the paper uses the Newton-Schulz iteration to accelerate the computation of covariance normalization. To obtain $\Sigma^{1/2} = U \Lambda^{1/2} U^{T}$, initialize $Y_0 = \Sigma$ and $Z_0 = I$, then alternately update

$$Y_n = \frac{1}{2} Y_{n-1}\left(3I - Z_{n-1} Y_{n-1}\right), \qquad Z_n = \frac{1}{2}\left(3I - Z_{n-1} Y_{n-1}\right) Z_{n-1}$$

Since the Newton-Schulz iteration only converges locally, $\Sigma$ is first pre-normalized to guarantee convergence:

$$\hat{\Sigma} = \frac{1}{\mathrm{tr}(\Sigma)} \Sigma$$

where $\mathrm{tr}(\Sigma) = \sum_{i}^{C} \lambda_i$ is the trace of $\Sigma$. In this case $\|\hat{\Sigma} - I\|_2$ equals the largest singular value of $\hat{\Sigma} - I$, namely $1 - \frac{\lambda_C}{\sum_i \lambda_i}$ (with $\lambda_C$ the smallest eigenvalue), which is less than 1, so the convergence condition is satisfied.

After the iterations, post-compensation is applied to undo the magnitude change introduced by the pre-normalization, giving the final normalized covariance matrix

$$\hat{Y} = \sqrt{\mathrm{tr}(\Sigma)}\, Y_N$$

where $N$ is the final iteration.
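
To see that this recipe (pre-normalize, iterate, post-compensate) really approximates $\Sigma^{1/2}$, here is a standalone single-matrix sketch. It is my illustration in plain PyTorch, not the repo's code:

def ns_sqrt(sigma, n_iter=10):
    """Newton-Schulz approximation of the square root of an SPD matrix."""
    dim = sigma.shape[0]
    I = torch.eye(dim, dtype=sigma.dtype)
    tr = sigma.diagonal().sum()
    Y, Z = sigma / tr, I.clone()          # pre-normalization ensures convergence
    for _ in range(n_iter):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y, Z = Y @ T, T @ Z
    return torch.sqrt(tr) * Y             # post-compensation

A = torch.randn(8, 8, dtype=torch.float64)
sigma = A @ A.T + 1e-3 * torch.eye(8, dtype=torch.float64)   # symmetric positive definite
Y_hat = ns_sqrt(sigma)
print(torch.allclose(Y_hat @ Y_hat, sigma, atol=1e-6))        # True: Y_hat ~ Sigma^{1/2}

In practice only a handful of iterations are used, since the result merely drives an attention branch and a slightly inexact square root is acceptable. In the repo the procedure, together with a hand-written backward pass that stores every intermediate $Y_n$ and $Z_n$, is implemented as the Sqrtm Function below.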

class Sqrtm(Function):
    """
        Matrix square root via Newton-Schulz iteration, with pre- and post-normalization
    """
    @staticmethod
    def forward(ctx, input, iterN):
        x = input
        batchSize = x.data.shape[0]

        dim = x.data.shape[1]
        dtype = x.dtype

        # 3I
        I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
        # compute tr(Sigma): the elementwise product with 3I keeps (three times)
        # the diagonal, so summing and dividing by 3 gives the trace
        normA = (1.0 / 3.0) * x.mul(I3).sum(dim=1).sum(dim=1)
        # pre-normalization: A = Sigma / tr(Sigma)
        A = x.div(normA.view(batchSize, 1, 1).expand_as(x))

        # allocate Y and Z with one slot per iteration (needed by the backward pass);
        # .type(dtype) keeps them consistent with the input (missing in the original)
        Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad=False, device=x.device).type(dtype)
        Z = torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, iterN, 1, 1).type(dtype)
        if iterN < 2:
            ZY = 0.5 * (I3 - A)
            # fixed: the single-iteration result is Y_1 = A * 0.5(3I - A);
            # the original code returned 0.5(3I - A) itself here
            ZY = A.bmm(ZY)
        else:
            # first iteration: Y_1 = 0.5 * Y_0 (3I - Z_0 Y_0) = 0.5 * A (3I - A)
            ZY = 0.5 * (I3 - A)
            Y[:, 0, :, :] = A.bmm(ZY)
            Z[:, 0, :, :] = ZY
            for i in range(1, iterN - 1):
                # 0.5 * (3I - Z_{n-1} Y_{n-1})
                ZY = 0.5 * (I3 - Z[:, i - 1, :, :].bmm(Y[:, i - 1, :, :]))
                Y[:, i, :, :] = Y[:, i - 1, :, :].bmm(ZY)
                Z[:, i, :, :] = ZY.bmm(Z[:, i - 1, :, :])
            # the last iteration does not need to update Z; compute Y_N directly
            ZY = 0.5 * Y[:, iterN - 2, :, :].bmm(I3 - Z[:, iterN - 2, :, :].bmm(Y[:, iterN - 2, :, :]))

        # post-compensation: y_hat = sqrt(tr(Sigma)) * Y_N
        y = ZY * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)

        ctx.save_for_backward(input, A, ZY, normA, Y, Z)
        ctx.iterN = iterN
        return y

    @staticmethod
    def backward(ctx, grad_output):
        input, A, ZY, normA, Y, Z = ctx.saved_tensors
        iterN = ctx.iterN
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        # gradient through the post-compensation y = sqrt(tr) * Y_N
        der_postCom = grad_output * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
        der_postComAux = (grad_output * ZY).sum(dim=1).sum(dim=1).div(2 * torch.sqrt(normA))
        I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
        if iterN < 2:
            # fixed: the original referenced an undefined name (der_sacleTrace) here
            der_NSiter = 0.5 * (der_postCom.bmm(I3 - A) - A.bmm(der_postCom))
        else:
            # backpropagate through the final Y-only update
            dldY = 0.5 * (der_postCom.bmm(I3 - Y[:, iterN - 2, :, :].bmm(Z[:, iterN - 2, :, :])) -
                          Z[:, iterN - 2, :, :].bmm(Y[:, iterN - 2, :, :]).bmm(der_postCom))
            dldZ = -0.5 * Y[:, iterN - 2, :, :].bmm(der_postCom).bmm(Y[:, iterN - 2, :, :])
            # unroll the coupled iteration in reverse
            for i in range(iterN - 3, -1, -1):
                YZ = I3 - Y[:, i, :, :].bmm(Z[:, i, :, :])
                ZY = Z[:, i, :, :].bmm(Y[:, i, :, :])
                dldY_ = 0.5 * (dldY.bmm(YZ) -
                               Z[:, i, :, :].bmm(dldZ).bmm(Z[:, i, :, :]) -
                               ZY.bmm(dldY))
                dldZ_ = 0.5 * (YZ.bmm(dldZ) -
                               Y[:, i, :, :].bmm(dldY).bmm(Y[:, i, :, :]) -
                               dldZ.bmm(ZY))
                dldY = dldY_
                dldZ = dldZ_
            der_NSiter = 0.5 * (dldY.bmm(I3 - A) - dldZ - A.bmm(dldY))
        # gradient through the pre-normalization A = Sigma / tr(Sigma)
        grad_input = der_NSiter.div(normA.view(batchSize, 1, 1).expand_as(x))
        grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1)
        for i in range(batchSize):
            grad_input[i, :, :] += (der_postComAux[i] \
                                    - grad_aux[i] / (normA[i] * normA[i])) \
                                   * torch.ones(dim, device=x.device).diag()
        return grad_input, None
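
The two Functions chain naturally; a quick shape check (my example, not from the repo):

x = torch.randn(2, 32, 10, 10, requires_grad=True)   # (batch, C, H, W) features
cov = Covpool.apply(x)                                # (2, 32, 32) covariance matrices
y = Sqrtm.apply(cov, 5)                               # normalized covariance, 5 NS iterations
y.mean().backward()                                   # exercises both hand-written backward passes
print(y.shape, x.grad.shape)   # torch.Size([2, 32, 32]) torch.Size([2, 32, 10, 10])

Roughly speaking, the repo's SOCA module (not shown here) then reduces the normalized covariance along one dimension into a C-dimensional statistic and feeds it through a small conv-plus-sigmoid gate to rescale the channels.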

Region-level Non-local (RL-NL)

The original non-local mechanism comes from Non-local Neural Networks. Because it is global, the computation becomes expensive when the feature map is large; moreover, experience shows that non-local operations within a suitably sized neighborhood fit low-level tasks better. The paper therefore adopts region-level non-local operations.

The feature map is divided into four sub-regions, the non-local operation is applied within each region, and the outputs are stitched back together.

import torch
from torch import nn


class Nonlocal_CA(nn.Module):
    def __init__(self, in_feat=64, inter_feat=32, reduction=8, sub_sample=False, bn_layer=True):
        super(Nonlocal_CA, self).__init__()
        # second-order channel attention (SOCA is defined elsewhere in the repo;
        # note that this forward() does not actually use it)
        self.soca = SOCA(in_feat, reduction=reduction)
        # non-local module (NONLocalBlock2D is defined in the next listing)
        self.non_local = NONLocalBlock2D(in_channels=in_feat, inter_channels=inter_feat,
                                         sub_sample=sub_sample, bn_layer=bn_layer)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # divide the feature map into 4 parts
        batch_size, C, H, W = x.shape
        H1 = int(H / 2)
        W1 = int(W / 2)
        nonlocal_feat = torch.zeros_like(x)

        # left-upper, left-lower, right-upper, right-lower quadrants
        feat_sub_lu = x[:, :, :H1, :W1]
        feat_sub_ld = x[:, :, H1:, :W1]
        feat_sub_ru = x[:, :, :H1, W1:]
        feat_sub_rd = x[:, :, H1:, W1:]

        # apply the shared non-local block to each region independently
        nonlocal_lu = self.non_local(feat_sub_lu)
        nonlocal_ld = self.non_local(feat_sub_ld)
        nonlocal_ru = self.non_local(feat_sub_ru)
        nonlocal_rd = self.non_local(feat_sub_rd)
        # stitch the four regions back together
        nonlocal_feat[:, :, :H1, :W1] = nonlocal_lu
        nonlocal_feat[:, :, H1:, :W1] = nonlocal_ld
        nonlocal_feat[:, :, :H1, W1:] = nonlocal_ru
        nonlocal_feat[:, :, H1:, W1:] = nonlocal_rd

        return nonlocal_feat
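
A quick shape check (my example; it assumes the repo's SOCA class is in scope, since the constructor instantiates one even though forward() never calls it):

block = Nonlocal_CA(in_feat=64, inter_feat=32, sub_sample=False, bn_layer=True)
x = torch.randn(1, 64, 48, 48)
print(block(x).shape)   # torch.Size([1, 64, 48, 48]), same shape as the input

Because H1 = int(H / 2), odd-sized inputs simply yield slightly unequal quadrants; the shared non-local block handles variable spatial sizes, so no padding is needed.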

    

The vanilla non-local module is as follows:

import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian',
                 sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()
        assert dimension in [1, 2, 3]
        assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation']

        self.mode = mode
        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool = nn.MaxPool3d
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool = nn.MaxPool2d
            # fixed: the original assigned `sub_sample = nn.Upsample` here, which
            # shadowed the sub_sample flag and made the pooling branch below always run
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool = nn.MaxPool1d
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = None
        self.phi = None
        self.concat_project = None
        if mode in ['embedded_gaussian', 'dot_product', 'concatenation']:
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                                 kernel_size=1, stride=1, padding=0)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)

            if mode == 'embedded_gaussian':
                self.operation_function = self._embedded_gaussian
            elif mode == 'dot_product':
                self.operation_function = self._dot_product
            elif mode == 'concatenation':
                self.operation_function = self._concatenation
                self.concat_project = nn.Sequential(
                    nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
                    nn.ReLU()
                )
        elif mode == 'gaussian':
            self.operation_function = self._gaussian

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool(kernel_size=2))
            if self.phi is None:
                self.phi = max_pool(kernel_size=2)
            else:
                self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2))

    def forward(self, x):
        '''
        :param x: (b, c, t, h, w)
        :return:
        '''

        output = self.operation_function(x)
        return output

    def _embedded_gaussian(self, x):
        batch_size,C,H,W = x.shape

        # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c)
        # phi  =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
        # f=>(b, thw, 0.5c)dot(b, 0.5c, thw) = (b, thw, thw)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)
        # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _gaussian(self, x):
        batch_size = x.size(0)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = x.view(batch_size, self.in_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        if self.sub_sample:
            phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
        else:
            phi_x = x.view(batch_size, self.in_channels, -1)

        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _dot_product(self, x):
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        N = f.size(-1)
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z

    def _concatenation(self, x):
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # (b, c, N, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
        # (b, c, 1, N)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)

        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)

        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        f = self.concat_project(concat_feature)
        b, _, h, w = f.size()
        f = f.view(b, h, w)

        N = f.size(-1)
        f_div_C = f / N

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, mode=mode,
                                              sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, mode=mode,
                                              sub_sample=sub_sample,
                                              bn_layer=bn_layer)
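
And the 2D wrapper on its own (again my example). Note the zero-initialization of W in the base class: the block starts out as an identity mapping (z = 0 + x), so it can be inserted into a network without disturbing its initial behavior:

nl = NONLocalBlock2D(in_channels=64, inter_channels=32, sub_sample=False, bn_layer=True)
x = torch.randn(1, 64, 24, 24)
print(nl(x).shape)   # torch.Size([1, 64, 24, 24])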

Conclusion

We propose a deep second-order attention network (SAN) for accurate image SR. Specifically, the non-locally enhanced residual group (NLRG) structure allows SAN to capture long-distance dependencies and structural information by embedding non-local operations into the network, and to bypass the abundant low-frequency information of the LR images through share-source skip connections. Besides exploiting spatial feature correlations, the second-order channel attention (SOCA) module learns feature interdependencies via global covariance pooling for more discriminative representations. Extensive experiments on SR with BI and BD degradation models demonstrate the effectiveness of our SAN in terms of both quantitative metrics and visual results.
