Image Super-Resolution with Cross-Scale Non-Local Attention
and Exhaustive Self-Exemplars Mining
# cross-scale non-local attention
# Assumed imports; `common` and the helpers extract_image_patches, reduce_sum
# and same_padding are all defined later in this post.
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import common

class CrossScaleAttention(nn.Module):
    def __init__(self, channel=128, reduction=2, ksize=3, scale=3, stride=1,
                 softmax_scale=10, average=True, conv=common.default_conv):
        super(CrossScaleAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.softmax_scale = softmax_scale
        self.scale = scale
        self.average = average
        escape_NaN = torch.FloatTensor([1e-4])
        self.register_buffer('escape_NaN', escape_NaN)
        self.conv_match_1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_match_2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())
        #self.register_buffer('fuse_weight', fuse_weight)

    def forward(self, input):
        # get embeddings
        embed_w = self.conv_assembly(input)
        match_input = self.conv_match_1(input)

        # b*c*h*w
        shape_input = list(embed_w.size())  # b*c*h*w
        input_groups = torch.split(match_input, 1, dim=0)

        # kernel size on the input for matching
        kernel = self.scale * self.ksize  # 3*3 = 9

        # raw_w is extracted for reconstruction:
        # slide a kernel x kernel window over embed_w and collect the patches
        raw_w = extract_image_patches(embed_w, ksizes=[kernel, kernel],
                                      strides=[self.stride*self.scale, self.stride*self.scale],
                                      rates=[1, 1],
                                      padding='same')  # [N, C*k*k, L]
        # raw_w shape: [N, C, k, k, L], i.e. [N, 128, 9, 9, L]
        raw_w = raw_w.view(shape_input[0], shape_input[1], kernel, kernel, -1)
        raw_w = raw_w.permute(0, 4, 1, 2, 3)  # raw_w shape: [N, L, C, k, k]
        raw_w_groups = torch.split(raw_w, 1, dim=0)

        # downscale X by a factor of self.scale (here 3) to form Y for cross-scale matching
        ref = F.interpolate(input, scale_factor=1./self.scale, mode='bilinear')
        ref = self.conv_match_2(ref)
        w = extract_image_patches(ref, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')
        shape_ref = ref.shape
        # w shape: [N, C, k, k, L]
        w = w.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
        w = w.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
        w_groups = torch.split(w, 1, dim=0)

        y = []
        scale = self.softmax_scale
        # 1*1*k*k
        #fuse_weight = self.fuse_weight
        for xi, wi, raw_wi in zip(input_groups, w_groups, raw_w_groups):
            # normalize each matching kernel to unit L2 norm
            wi = wi[0]  # [L, C, k, k]
            max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
                                                     axis=[1, 2, 3],
                                                     keepdim=True)),
                               self.escape_NaN)
            wi_normed = wi / max_wi
            # Compute correlation map
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L, H, W], L = shape_ref[2]*shape_ref[3]
            yi = yi.view(1, shape_ref[2]*shape_ref[3], shape_input[2], shape_input[3])  # e.g. for a 96x96 input (scale 3): (1, 32*32, 96, 96)
            # rescale matching score
            yi = F.softmax(yi*scale, dim=1)
            if not self.average:
                # hard attention: keep only the best-matching patch
                yi = (yi == yi.max(dim=1, keepdim=True)[0]).float()
            # deconv (transposed convolution) for reconstruction
            wi_center = raw_wi[0]
            yi = F.conv_transpose2d(yi, wi_center, stride=self.stride*self.scale, padding=self.scale)
            yi = yi/6.  # normalization constant from the original implementation (overlapping windows)
            y.append(yi)
        y = torch.cat(y, dim=0)
        return y
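A minimal usage sketch (assuming common and the helpers extract_image_patches, reduce_sum and same_padding shown below are all in scope, and that the spatial size is divisible by scale):

import torch

attn = CrossScaleAttention(channel=128, reduction=2, ksize=3, scale=3)
x = torch.randn(2, 128, 30, 30)  # (N, C, H, W) feature map
y = attn(x)
print(y.shape)  # torch.Size([2, 128, 90, 90]): the output is upscaled by `scale`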
In the class above:
conv=common.default_conv
def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), stride=stride, bias=bias)
self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())
class BasicBlock(nn.Sequential):
    def __init__(
        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
        bn=False, act=nn.PReLU()):

        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)
Assume the input has shape (N, C, H, W) with channel=128.
embed_w = self.conv_assembly(input)
self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())
self.conv_assembly is equivalent to nn.Conv2d(128, 128, kernel_size=1, padding=0, stride=1) + nn.PReLU(), so its output is (N, 128, H, W).
match_input = self.conv_match_1(input)
self.conv_match_1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
reduction=2
so self.conv_match_1 is equivalent to nn.Conv2d(128, 64, kernel_size=1, padding=0, stride=1) + nn.PReLU(), and its output is (N, 64, H, W).
input_groups = torch.split(match_input,1,dim=0)
torch.split(tensor, split_size_or_sections, dim=0)
torch.split() splits a tensor into chunks.
Parameters:
tensor: the input tensor to split
split_size_or_sections: the size of each chunk (int) or a list of chunk sizes
dim (int): the dimension to split along; dim=0 splits by rows, dim=1 by columns
Output: the resulting chunks
When split_size_or_sections is an int that divides the tensor evenly, all output chunks have the same size; otherwise the remainder is returned as a final, smaller chunk.
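A quick check of this behavior (tensor shapes chosen for illustration):

import torch

x = torch.randn(4, 64, 32, 32)       # a batch of 4 feature maps
groups = torch.split(x, 1, dim=0)    # split along the batch dimension, as in forward()
print(len(groups), groups[0].shape)  # 4 torch.Size([1, 64, 32, 32])

y = torch.arange(5)
print(torch.split(y, 2))  # (tensor([0, 1]), tensor([2, 3]), tensor([4])): the last chunk is smaller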
# raw_w is extracted for reconstruction
raw_w = extract_image_patches(embed_w, ksizes=[kernel, kernel],
                              strides=[self.stride*self.scale, self.stride*self.scale],
                              rates=[1, 1],
                              padding='same')  # [N, C*k*k, L]
# raw_shape: [N, C, k, k, L]
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.
    :param padding: 'valid' means each patch must be fully contained in the
        image; 'same' allows incomplete patches (the input is zero-padded)
    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window
        for each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :return: A Tensor
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    batch_size, channel, height, width = images.size()

    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding == 'valid':
        pass
    else:
        raise NotImplementedError('Unsupported padding type: {}. '
                                  'Only "same" or "valid" are supported.'.format(padding))

    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks
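To make the shapes concrete, here is a check with the values this module uses (kernel = scale*ksize = 9, stride = stride*scale = 3) on a hypothetical 32×32 feature map:

import torch

embed_w = torch.randn(1, 128, 32, 32)
patches = extract_image_patches(embed_w, ksizes=[9, 9], strides=[3, 3],
                                rates=[1, 1], padding='same')
# 'same' padding gives ceil(32/3) = 11 window positions per axis, so L = 11*11 = 121
print(patches.shape)  # torch.Size([1, 10368, 121]), where 10368 = 128*9*9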
In other words, it extracts sliding local blocks from a batch of input samples, i.e. the sliding-window operation behind local connectivity.
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
Unfold's input has shape (N, C, H, W), where N is the batch size, C the number of channels, and H and W the spatial height and width.
The output has shape (N, C×∏(kernel_size), L), where ∏ denotes the product: ∏(kernel_size) is the product of the kernel's height and width, and L is the number of blocks obtained by sliding a kernel-sized window over each channel.
For example, an input of (1, 2, 4, 4) gives an output of (1, 8, 4):
import torch
inputs = torch.randn(1, 2, 4, 4)
print(inputs.size())
print(inputs)
unfold = torch.nn.Unfold(kernel_size=(2, 2), stride=2)
patches = unfold(inputs)
print(patches.size())
print(patches)
Analyzing the output: nn.Unfold flattens every kernel_size[0]×kernel_size[1] sliding-window block of each input channel.
def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
    """
    Up- or down-samples the input to the given size or by the given scale_factor.
    Currently supports temporal, spatial and volumetric input, with shapes
    3-D, 4-D and 5-D respectively, laid out as
    mini-batch x channels x [optional depth] x [optional height] x width.
    Available algorithms: nearest, linear (3D only), bilinear (4D only),
    trilinear (5D only). The default mode is 'nearest'.
    Parameters:
    - input (Tensor): input tensor
    - size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): output spatial size
    - scale_factor (float or Tuple[float]): multiplier for the spatial size
    - mode (string): algorithm: nearest, linear, bilinear, trilinear, area. Default: nearest
    - align_corners (bool, optional): if True, the corner pixels of the input and
      output are aligned, preserving the values at those pixels. Only has an effect
      when mode is linear, bilinear or trilinear. Default: False
    """
For example, with an input of (1, 3, 64, 64), downsample or upsample:
x = torch.randn([1, 3, 64, 64])  # torch.autograd.Variable is deprecated; a plain tensor works
y0 = F.interpolate(x, scale_factor=0.5)
y1 = F.interpolate(x, size=[32, 32])
y2 = F.interpolate(x, size=[128, 128], mode="bilinear")
print(y0.shape)
print(y1.shape)
print(y2.shape)
#结果
#torch.Size([1, 3, 32, 32])
#torch.Size([1, 3, 32, 32])
#torch.Size([1, 3, 128, 128])
The zip() function takes iterables as arguments, packs their corresponding elements into tuples, and returns an iterator over those tuples (a list in Python 2).
for xi, wi, raw_wi in zip(input_groups, w_groups, raw_w_groups):
reduce here means "reduce the dimensionality of a tensor", and the part after the underscore names the reduction; reduce_sum() reduces by summing.
max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
                                         axis=[1, 2, 3],
                                         keepdim=True)),
                   self.escape_NaN)
def reduce_sum(x, axis=None, keepdim=False):
    if axis is None:  # note: `if not axis` would also trigger on axis=0 or an empty list
        axis = range(len(x.shape))
    # sum from the highest dimension down so the remaining indices stay valid
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x
axis: the dimension(s) along which to sum.
keepdim: whether to keep the original number of dimensions; with False, each reduced dimension is removed from the result.
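A quick check on the shapes used above (values chosen for illustration):

import torch

wi = torch.randn(121, 64, 3, 3)  # [L, C, k, k]
norms = reduce_sum(torch.pow(wi, 2), axis=[1, 2, 3], keepdim=True)
print(norms.shape)  # torch.Size([121, 1, 1, 1]): one squared L2 norm per patch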
The built-in sorted() function sorts any iterable.
sorted(iterable, key=None, reverse=False)
reverse: the sort order; reverse=True sorts descending, reverse=False ascending (the default).
>>> sorted([5, 2, 3, 1, 4])
[1, 2, 3, 4, 5]  # ascending by default
The difference between sort and sorted:
sort is a method of list, while sorted accepts any iterable.
list.sort() sorts the existing list in place (and returns None), whereas the built-in sorted() returns a new list rather than modifying the original.
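The difference in one example:

a = [5, 2, 3, 1, 4]
b = sorted(a)           # new list; a is unchanged
print(a, b)             # [5, 2, 3, 1, 4] [1, 2, 3, 4, 5]
a.sort(reverse=True)    # sorts in place, descending; returns None
print(a)                # [5, 4, 3, 2, 1]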
pow() returns x**y (x raised to the power y); torch.pow(wi, 2) squares every element of wi.
# Compute correlation map
xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1]) # xi: 1*c*H*W
def same_padding(images, ksizes, strides, rates):
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows - 1)*strides[0] + effective_k_row - rows)
    padding_cols = max(0, (out_cols - 1)*strides[1] + effective_k_col - cols)
    # Pad the input
    padding_top = int(padding_rows / 2.)
    padding_left = int(padding_cols / 2.)
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images
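For the matching step above (3×3 kernel, stride 1, dilation 1), 'same' padding keeps the spatial size, so the correlation map yi has one score per input pixel. A shape check (sizes assumed):

import torch

xi = torch.randn(1, 64, 32, 32)
xi_padded = same_padding(xi, [3, 3], [1, 1], [1, 1])
print(xi_padded.shape)  # torch.Size([1, 64, 34, 34]); a 3x3 conv at stride 1 then returns 32x32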
One way to think about it: nn.Conv2d is the 2D convolution layer (a module that owns its weights), while F.conv2d is the 2D convolution operation (a function that takes the weights explicitly).
import torch
from torch.nn import functional as F
"""Manually define the convolution kernels (weight) and bias"""
w = torch.rand(16, 3, 5, 5)  # 16 kernels of size 5x5 over 3 input channels
b = torch.rand(16)           # one bias per output channel (shared across spatial positions)
"""Define an input sample"""
x = torch.randn(1, 3, 28, 28)  # one 3-channel 28x28 image
"""2D convolution"""
out = F.conv2d(x, w, b, stride=1, padding=1)  # stride 1, one ring of padding
print(out.shape)
out = F.conv2d(x, w, b, stride=2, padding=2)  # stride 2, two rings of padding
print(out.shape)
# results
#torch.Size([1, 16, 26, 26])
#torch.Size([1, 16, 14, 14])
yi = F.conv2d(xi, wi_normed, stride=1) # [1, L, H, W] L = shape_ref[2]*shape_ref[3]
ConvTranspose2d implements the inverse of Conv2d's shape mapping, i.e. it upsamples an m×m image to n×n.
yi = F.conv_transpose2d(yi, wi_center, stride=self.stride*self.scale, padding=self.scale)
For example:
>>> # With square kernels and equal stride
>>> inputs = torch.randn(1, 4, 5, 5)    # N, C, H, W
>>> weights = torch.randn(4, 8, 3, 3)   # in_channels, out_channels, kH, kW
>>> F.conv_transpose2d(inputs, weights, padding=1)  # -> shape (1, 8, 5, 5): (5-1)*1 - 2*1 + 3 = 5
torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                         stride=1, padding=0, output_padding=0,
                         groups=1, bias=True, dilation=1, padding_mode='zeros')
in_channels (int): number of channels of the input tensor
out_channels (int): number of channels of the output tensor
kernel_size (int or tuple): size of the convolution kernel
stride (int or tuple, optional): stride of the convolution; determines the upsampling factor
padding (int or tuple, optional): implicit padding; unlike Conv2d, for a transposed convolution this shrinks the output size by 2*padding (see the formula below)
output_padding (int or tuple, optional): extra size added to one side of the output; the output size grows by output_padding
groups: grouped convolution (must evenly divide both in_channels and out_channels)
bias: whether to add a bias
dilation: spacing between kernel elements (i.e. dilated/atrous convolution)
padding_mode (str): the type of padding
For any parameter that accepts a tuple, tuple[0] applies to the height dimension and tuple[1] to the width dimension.
Output size (note this is the inverse of the Conv2d formula Height_out = (Height_in + 2*padding - kernel_size)/stride + 1, which runs in the forward-convolution direction):
Height_out = (Height_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1
The width is computed the same way.
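Plugging in the values CrossScaleAttention uses (ksize=3, scale=3, stride=1, with dilation=1 and output_padding=0): the reconstruction kernel has size scale*ksize = 9, the stride is stride*scale = 3, and padding = scale = 3, so for an H×W attention map
Height_out = (H - 1)*3 - 2*3 + (9 - 1) + 1 = 3H,
i.e. the transposed convolution in forward() upsamples the feature map by exactly the scale factor. A shape check (sizes assumed):

import torch
import torch.nn.functional as F

yi = torch.randn(1, 121, 32, 32)         # [1, L, H, W] attention map, L = 11*11
wi_center = torch.randn(121, 128, 9, 9)  # [L, C, 9, 9]: raw patches used as deconv weights
out = F.conv_transpose2d(yi, wi_center, stride=3, padding=3)
print(out.shape)  # torch.Size([1, 128, 96, 96]): 3x the input's 32x32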
torch.cat joins tensors together; cat is short for concatenate.
y = torch.cat(y, dim=0)
dim=0: concatenate along dimension 0 (vertically, by rows)
dim=1: concatenate along dimension 1 (horizontally, by columns)
C = torch.cat((A, B), 0) concatenates A and B along dimension 0 (rows), i.e. vertically with A on top and B below. Note that the column counts must match (dimension 1 must be equal, e.g. both 3 columns) for the columns to align. Dimension 0 of C is the sum of the inputs' dimension 0, e.g. 2+4=6.
C = torch.cat((A, B), 1) concatenates A and B along dimension 1 (columns), i.e. horizontally with A on the left and B on the right. Note that the row counts must match (dimension 0 must be equal, e.g. both 2 rows) for the rows to align. Dimension 1 of C is the sum of the inputs' dimension 1, e.g. 3+4=7.
As the 2-D example shows, torch.cat((A, B), dim) requires every dimension except dim to match.
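The 2-D example in code:

import torch

A = torch.zeros(2, 3)
B = torch.ones(4, 3)
print(torch.cat((A, B), 0).shape)  # torch.Size([6, 3]): rows add up, column counts must match

A = torch.zeros(2, 3)
B = torch.ones(2, 4)
print(torch.cat((A, B), 1).shape)  # torch.Size([2, 7]): columns add up, row counts must match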