Paper: Second-Order Attention Network for Single Image Super-Resolution
Code: daitao/SAN: Second-order Attention Network for Single Image Super-resolution (CVPR 2019)
The paper proposes a channel attention mechanism based on second-order statistics, which yields more powerful representations. It also refines the non-local mechanism: for low-level tasks, applying non-local attention directly over the whole image is computationally prohibitive, so the paper instead applies it at region level, over patches.

Most existing CNN-based methods focus on designing wider or deeper networks while neglecting the feature correlations of intermediate layers, which limits the representational power of CNNs.

The channel attention in SENet only exploits first-order statistics (e.g., global average pooling) and ignores statistics higher than first order, which limits the discriminative ability of the network.
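For contrast, here is a minimal sketch of that SENet-style first-order channel attention (not from the paper; the class name and layer sizes are illustrative):

```python
import torch
import torch.nn as nn

class SEAttention(nn.Module):
    """First-order channel attention (SENet-style): the channel statistic
    is just the per-channel mean from global average pooling."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)           # first-order statistic
        self.gate = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.gate(self.pool(x))            # channel-wise rescaling
```

SAN replaces this global-average statistic with second-order (covariance) statistics, described next.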
Given an input feature map, reshape it to $X \in \mathbb{R}^{C \times s}$, where $s = WH$, and compute the sample covariance matrix:

$$\Sigma = X \bar{I} X^{T}, \qquad \bar{I} = \frac{1}{s}\left(I - \frac{1}{s}\mathbf{1}\right)$$

where $I$ and $\mathbf{1}$ are the $s \times s$ identity matrix and the matrix of all ones, respectively.
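As a quick sanity check (my own snippet, not from the paper), $X\bar{I}X^{T}$ matches the covariance obtained by explicitly mean-centering the $C$ channel vectors over the $s$ spatial positions:

```python
import torch

C, H, W = 8, 5, 5
X = torch.randn(C, H * W)                    # features reshaped to C x s
s = X.shape[1]

I_bar = (torch.eye(s) - torch.ones(s, s) / s) / s
Sigma = X @ I_bar @ X.T                      # covariance via the paper's formula

X_c = X - X.mean(dim=1, keepdim=True)        # explicit mean-centering
Sigma_ref = X_c @ X_c.T / s                  # sample covariance, normalized by s

print(torch.allclose(Sigma, Sigma_ref, atol=1e-5))  # True
```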
Since covariance normalization improves the discriminative power of the representation, $\Sigma$ is normalized. Because $\Sigma$ is symmetric positive semi-definite, it has the eigenvalue decomposition (EIG)

$$\Sigma = U \Lambda U^{T}$$

where $U$ is an orthogonal matrix and $\Lambda = \mathrm{diag}(\lambda_1, \cdots, \lambda_C)$ is a diagonal matrix with the eigenvalues in non-increasing order. Covariance normalization then reduces to a power of the eigenvalues:

$$\hat{Y} = \Sigma^{\alpha} = U \Lambda^{\alpha} U^{T}$$

When $\alpha < 1$, this nonlinearly shrinks the eigenvalues larger than 1 and amplifies those smaller than 1. According to reference [1] of the paper, $\alpha = 0.5$ gives the best representation.
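A minimal sketch of this EIG-based normalization (illustrative; `cov_normalize` is a made-up helper, and the paper instead computes the square root iteratively, as shown later):

```python
import torch

def cov_normalize(Sigma, alpha=0.5):
    """Sigma^alpha = U diag(lambda_i^alpha) U^T for a symmetric PSD matrix."""
    lam, U = torch.linalg.eigh(Sigma)   # eigenvalues (ascending) and eigenvectors
    lam = lam.clamp_min(0)              # guard tiny negative round-off values
    return U @ torch.diag(lam ** alpha) @ U.T

X = torch.randn(8, 25)
Sigma = (X @ X.T) / 25
Y_hat = cov_normalize(Sigma, alpha=0.5)
# eigenvalues > 1 are shrunk, eigenvalues in (0, 1) are amplified
```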
```python
import torch
from torch.autograd import Function


class Covpool(Function):
    """
    Global covariance pooling layer
    """
    @staticmethod
    def forward(ctx, input):
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]   # C
        h = x.data.shape[2]
        w = x.data.shape[3]
        M = h * w               # s = HW
        # Sigma = X I_hat X^T, where I_hat is s x s, so reshape x to (dim, M)
        x = x.reshape(batchSize, dim, M)
        # I_hat = (1/s)(I - (1/s)1) = (1/s)*I + (-1/s/s)*1,
        # where I and 1 are the s x s identity matrix and matrix of all ones
        I_hat = (1. / M) * torch.eye(M, M, device=x.device) + \
                (-1. / M / M) * torch.ones(M, M, device=x.device)
        # broadcast I_hat over the batch dimension
        I_hat = I_hat.view(1, M, M).repeat(batchSize, 1, 1).type(x.dtype)
        # covariance: y = x I_hat x^T; x is (b, dim, M), so transpose dims 1 and 2
        # x.bmm(I_hat) is a batched matrix multiplication of x and I_hat
        y = x.bmm(I_hat).bmm(x.transpose(1, 2))
        # saved for the backward pass
        ctx.save_for_backward(input, I_hat)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        input, I_hat = ctx.saved_tensors
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        h = x.data.shape[2]
        w = x.data.shape[3]
        M = h * w
        x = x.reshape(batchSize, dim, M)
        # gradient of x A x^T with symmetric A = I_hat: (g + g^T) x I_hat
        grad_input = grad_output + grad_output.transpose(1, 2)
        grad_input = grad_input.bmm(x).bmm(I_hat)
        grad_input = grad_input.reshape(batchSize, dim, h, w)
        return grad_input
```
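Since `Covpool` is a custom autograd `Function`, it is invoked through `.apply`; a quick shape check:

```python
CovpoolLayer = Covpool.apply

x = torch.randn(2, 64, 16, 16, requires_grad=True)  # (B, C, H, W)
sigma = CovpoolLayer(x)                             # per-sample covariance
print(sigma.shape)                                  # torch.Size([2, 64, 64])
sigma.sum().backward()                              # exercises the custom backward
```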
Inspired by "Towards Faster Training of Global Covariance Pooling Networks by Iterative Matrix Square Root Normalization" (iSQRT-COV), the paper uses the Newton-Schulz iteration to speed up the computation of covariance normalization. For $\Sigma^{1/2} = U \Lambda^{1/2} U^{T}$, initialize $Y_0 = \Sigma$ and $Z_0 = I$, then alternately update:

$$Y_n = \frac{1}{2} Y_{n-1}\left(3I - Z_{n-1} Y_{n-1}\right), \qquad Z_n = \frac{1}{2}\left(3I - Z_{n-1} Y_{n-1}\right) Z_{n-1}$$
Since the Newton-Schulz iteration only converges locally, $\Sigma$ is first pre-normalized to guarantee convergence:

$$\hat{\Sigma} = \frac{1}{\mathrm{tr}(\Sigma)} \Sigma$$

where $\mathrm{tr}(\Sigma) = \sum_{i}^{C} \lambda_i$ is the trace of $\Sigma$. In this case, $\lVert \hat{\Sigma} - I \rVert_2$ equals the largest singular value of $\hat{\Sigma} - I$; since every $1 - \frac{\lambda_i}{\sum_i \lambda_i}$ is less than 1, the convergence condition is satisfied.
After the iterations, a post-compensation step undoes the magnitude change introduced by pre-normalization, producing the final normalized covariance matrix:

$$\hat{Y} = \sqrt{\mathrm{tr}(\Sigma)}\; Y_N, \qquad \text{where } N \text{ is the final iteration}$$
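As a standalone check (my own sketch, not the repo's code), the pre-normalized iteration plus post-compensation does approach the true matrix square root on a small SPD matrix:

```python
import torch

def ns_sqrt(Sigma, n_iter=5):
    """Matrix square root via pre-normalized Newton-Schulz iteration."""
    dim = Sigma.shape[0]
    I = torch.eye(dim, dtype=Sigma.dtype)
    tr = Sigma.trace()
    Y, Z = Sigma / tr, I.clone()        # pre-norm: Y0 = Sigma/tr(Sigma), Z0 = I
    for _ in range(n_iter):
        T = 0.5 * (3 * I - Z @ Y)
        Y, Z = Y @ T, T @ Z
    return torch.sqrt(tr) * Y           # post-compensation

A = torch.randn(16, 16, dtype=torch.float64)
Sigma = A @ A.T + torch.eye(16, dtype=torch.float64)  # SPD test matrix
S = ns_sqrt(Sigma, n_iter=10)
print(torch.dist(S @ S, Sigma))         # small residual, shrinks with n_iter
```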
```python
class Sqrtm(Function):
    """
    Matrix square root via Newton-Schulz iteration, with pre-normalization
    and post-compensation
    """
    @staticmethod
    def forward(ctx, input, iterN):
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        # 3I, broadcast over the batch
        I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
        # tr(Sigma): x.mul(I3) keeps 3 * diagonal, summing and dividing by 3 gives the trace
        normA = (1.0 / 3.0) * x.mul(I3).sum(dim=1).sum(dim=1)
        # pre-normalization: A = Sigma / tr(Sigma)
        A = x.div(normA.view(batchSize, 1, 1).expand_as(x))
        # Y and Z store every iterate; they are needed in the backward pass
        Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad=False, device=x.device)
        Z = torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, iterN, 1, 1)
        if iterN < 2:
            ZY = 0.5 * (I3 - A)
            Y[:, 0, :, :] = A.bmm(ZY)
        else:
            # first iteration: Y_1 = 0.5 * Y_0 (3I - Z_0 Y_0) = 0.5 * A (3I - A)
            ZY = 0.5 * (I3 - A)
            Y[:, 0, :, :] = A.bmm(ZY)
            Z[:, 0, :, :] = ZY
            for i in range(1, iterN - 1):
                # 0.5 * (3I - Z_{n-1} Y_{n-1})
                ZY = 0.5 * (I3 - Z[:, i - 1, :, :].bmm(Y[:, i - 1, :, :]))
                Y[:, i, :, :] = Y[:, i - 1, :, :].bmm(ZY)
                Z[:, i, :, :] = ZY.bmm(Z[:, i - 1, :, :])
            # last iteration: only Y is needed, so Z is not updated
            ZY = 0.5 * Y[:, iterN - 2, :, :].bmm(I3 - Z[:, iterN - 2, :, :].bmm(Y[:, iterN - 2, :, :]))
        # post-compensation: y_hat = sqrt(tr(Sigma)) * Y_N
        y = ZY * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
        ctx.save_for_backward(input, A, ZY, normA, Y, Z)
        ctx.iterN = iterN
        return y

    @staticmethod
    def backward(ctx, grad_output):
        input, A, ZY, normA, Y, Z = ctx.saved_tensors
        iterN = ctx.iterN
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        # gradient through the post-compensation step
        der_postCom = grad_output * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
        der_postComAux = (grad_output * ZY).sum(dim=1).sum(dim=1).div(2 * torch.sqrt(normA))
        I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
        if iterN < 2:
            # note: the original repo references an undefined `der_sacleTrace` here;
            # der_postCom is the intended gradient (as in the iSQRT-COV reference code)
            der_NSiter = 0.5 * (der_postCom.bmm(I3 - A) - A.bmm(der_postCom))
        else:
            # back-propagate through the Newton-Schulz iterations, last to first
            dldY = 0.5 * (der_postCom.bmm(I3 - Y[:, iterN - 2, :, :].bmm(Z[:, iterN - 2, :, :])) -
                          Z[:, iterN - 2, :, :].bmm(Y[:, iterN - 2, :, :]).bmm(der_postCom))
            dldZ = -0.5 * Y[:, iterN - 2, :, :].bmm(der_postCom).bmm(Y[:, iterN - 2, :, :])
            for i in range(iterN - 3, -1, -1):
                YZ = I3 - Y[:, i, :, :].bmm(Z[:, i, :, :])
                ZY = Z[:, i, :, :].bmm(Y[:, i, :, :])
                dldY_ = 0.5 * (dldY.bmm(YZ) -
                               Z[:, i, :, :].bmm(dldZ).bmm(Z[:, i, :, :]) -
                               ZY.bmm(dldY))
                dldZ_ = 0.5 * (YZ.bmm(dldZ) -
                               Y[:, i, :, :].bmm(dldY).bmm(Y[:, i, :, :]) -
                               dldZ.bmm(ZY))
                dldY = dldY_
                dldZ = dldZ_
            der_NSiter = 0.5 * (dldY.bmm(I3 - A) - dldZ - A.bmm(dldY))
        # gradient through the pre-normalization step
        grad_input = der_NSiter.div(normA.view(batchSize, 1, 1).expand_as(x))
        grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1)
        for i in range(batchSize):
            grad_input[i, :, :] += (der_postComAux[i]
                                    - grad_aux[i] / (normA[i] * normA[i])) \
                                   * torch.ones(dim, device=x.device).diag()
        return grad_input, None
```
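These two `Function`s are what the paper's SOCA module is built on. The repo's exact SOCA layer is not reproduced in this post, but from the paper's description it looks roughly like the following sketch (covariance pooling, square-root normalization, a row-wise mean as the channel statistic, then a bottleneck gate; details such as the reduction ratio may differ from the repo):

```python
import torch.nn as nn

CovpoolLayer = Covpool.apply
SqrtmLayer = Sqrtm.apply

class SOCA(nn.Module):
    """Second-order channel attention (sketch): channel weights come from the
    square-root-normalized covariance matrix rather than global average pooling."""
    def __init__(self, channels, reduction=8, iterN=5):
        super().__init__()
        self.iterN = iterN
        self.gate = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        cov = CovpoolLayer(x)                    # (b, c, c) covariance
        cov = SqrtmLayer(cov, self.iterN)        # Newton-Schulz: Sigma^{1/2}
        stat = cov.mean(dim=2).view(b, c, 1, 1)  # row-wise mean per channel
        return x * self.gate(stat)
```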
The original non-local mechanism is described in Non-local Neural Networks. Because it is global, the computation becomes very expensive when the feature map is large; moreover, experience shows that performing non-local operations within a neighborhood of suitable size works well for low-level tasks. The paper therefore adopts region-level non-local operations.

The feature map is split into four patches, region-level non-local attention is computed within each patch, and the results are stitched back together (see the usage sketch after the code below).
```python
import torch
import torch.nn as nn


class Nonlocal_CA(nn.Module):
    def __init__(self, in_feat=64, inter_feat=32, reduction=8, sub_sample=False, bn_layer=True):
        super(Nonlocal_CA, self).__init__()
        # second-order channel attention
        self.soca = SOCA(in_feat, reduction=reduction)
        # non-local module
        self.non_local = NONLocalBlock2D(in_channels=in_feat, inter_channels=inter_feat,
                                         sub_sample=sub_sample, bn_layer=bn_layer)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # divide the feature map into 4 patches
        batch_size, C, H, W = x.shape
        H1 = int(H / 2)
        W1 = int(W / 2)
        nonlocal_feat = torch.zeros_like(x)

        feat_sub_lu = x[:, :, :H1, :W1]   # left-upper
        feat_sub_ld = x[:, :, H1:, :W1]   # left-lower
        feat_sub_ru = x[:, :, :H1, W1:]   # right-upper
        feat_sub_rd = x[:, :, H1:, W1:]   # right-lower

        # region-level non-local within each patch
        nonlocal_lu = self.non_local(feat_sub_lu)
        nonlocal_ld = self.non_local(feat_sub_ld)
        nonlocal_ru = self.non_local(feat_sub_ru)
        nonlocal_rd = self.non_local(feat_sub_rd)

        # stitch the four patches back together
        nonlocal_feat[:, :, :H1, :W1] = nonlocal_lu
        nonlocal_feat[:, :, H1:, :W1] = nonlocal_ld
        nonlocal_feat[:, :, :H1, W1:] = nonlocal_ru
        nonlocal_feat[:, :, H1:, W1:] = nonlocal_rd

        return nonlocal_feat
```
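Illustrative usage (assuming the `SOCA` sketch above and `NONLocalBlock2D`, defined below, are in scope): the block is plug-and-play on a `(B, C, H, W)` feature map:

```python
block = Nonlocal_CA(in_feat=64, inter_feat=32, sub_sample=False, bn_layer=True)
x = torch.randn(1, 64, 48, 48)
out = block(x)      # non-local attention applied independently in each quadrant
print(out.shape)    # torch.Size([1, 64, 48, 48])
```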
The vanilla non-local module is as follows:
```python
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian',
                 sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()
        assert dimension in [1, 2, 3]
        assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation']

        self.mode = mode
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        # default bottleneck: half the input channels
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool = nn.MaxPool3d
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool = nn.MaxPool2d
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool = nn.MaxPool1d
            bn = nn.BatchNorm1d

        # g: 1x1 conv producing the value embedding
        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        # W: output projection back to in_channels, zero-initialized so the
        # block starts as an identity (residual) mapping
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = None
        self.phi = None
        self.concat_project = None

        if mode in ['embedded_gaussian', 'dot_product', 'concatenation']:
            # theta and phi: 1x1 convs producing the query/key embeddings
            self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                                 kernel_size=1, stride=1, padding=0)
            self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)

            if mode == 'embedded_gaussian':
                self.operation_function = self._embedded_gaussian
            elif mode == 'dot_product':
                self.operation_function = self._dot_product
            elif mode == 'concatenation':
                self.operation_function = self._concatenation
                self.concat_project = nn.Sequential(
                    nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
                    nn.ReLU()
                )
        elif mode == 'gaussian':
            self.operation_function = self._gaussian

        # optionally sub-sample the key/value maps to reduce the pairwise cost
        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool(kernel_size=2))
            if self.phi is None:
                self.phi = max_pool(kernel_size=2)
            else:
                self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2))

    def forward(self, x):
        '''
        :param x: (b, c, t, h, w)
        :return:
        '''
        output = self.operation_function(x)
        return output

    def _embedded_gaussian(self, x):
        batch_size, C, H, W = x.shape
        # g     => (b, c, h, w) -> (b, 0.5c, h, w) -> (b, hw, 0.5c)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        # theta => (b, c, h, w) -> (b, 0.5c, h, w) -> (b, hw, 0.5c)
        # phi   => (b, c, h, w) -> (b, 0.5c, h, w) -> (b, 0.5c, hw)
        # f = theta . phi => (b, hw, 0.5c) x (b, 0.5c, hw) = (b, hw, hw)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)
        # (b, hw, hw) x (b, hw, 0.5c) = (b, hw, 0.5c) -> (b, 0.5c, h, w) -> (b, c, h, w)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z

    def _gaussian(self, x):
        batch_size = x.size(0)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        # gaussian mode uses the raw features as query/key
        theta_x = x.view(batch_size, self.in_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        if self.sub_sample:
            phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
        else:
            phi_x = x.view(batch_size, self.in_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z

    def _dot_product(self, x):
        batch_size = x.size(0)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        # dot-product mode normalizes by the number of positions instead of softmax
        N = f.size(-1)
        f_div_C = f / N
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z

    def _concatenation(self, x):
        batch_size = x.size(0)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        # (b, c, N, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
        # (b, c, 1, N)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)
        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)
        # pairwise similarity from a learned projection of the concatenated pair
        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        f = self.concat_project(concat_feature)
        b, _, h, w = f.size()
        f = f.view(b, h, w)
        N = f.size(-1)
        f_div_C = f / N
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, mode=mode,
                                              sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, mode=mode,
                                              sub_sample=sub_sample,
                                              bn_layer=bn_layer)
```
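Illustrative usage of the embedded-Gaussian 2D variant used by `Nonlocal_CA`:

```python
nl = NONLocalBlock2D(in_channels=64, inter_channels=32, sub_sample=False, bn_layer=True)
x = torch.randn(1, 64, 24, 24)
z = nl(x)           # residual output z = W(y) + x, same shape as the input
print(z.shape)      # torch.Size([1, 64, 24, 24])
```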
The paper's conclusion: we propose a deep second-order attention network (SAN) for accurate image SR. Specifically, the non-locally enhanced residual group (NLRG) structure allows SAN to capture long-distance dependencies and structural information by embedding non-local operations into the network; NLRG also lets the abundant low-frequency information of the LR images be bypassed through share-source skip connections. Beyond exploiting spatial feature correlations, the second-order channel attention (SOCA) module learns feature interdependencies via global covariance pooling, yielding more discriminative representations. Extensive experiments on SR with BI and BD degradation models demonstrate the effectiveness of SAN in terms of both quantitative metrics and visual quality.