cross-scale non-local attention

Image Super-Resolution with Cross-Scale Non-Local Attention
and Exhaustive Self-Exemplars Mining

# cross-scale non-local attention
# (the module below relies on torch, torch.nn, torch.nn.functional and the repo's `common` module,
#  plus the helpers extract_image_patches / same_padding / reduce_sum explained later in this post)
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossScaleAttention(nn.Module):
    def __init__(self, channel=128, reduction=2, ksize=3, scale=3, stride=1, softmax_scale=10, average=True, conv=common.default_conv):
        super(CrossScaleAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.softmax_scale = softmax_scale
        
        self.scale = scale
        self.average = average
        escape_NaN = torch.FloatTensor([1e-4])
        self.register_buffer('escape_NaN', escape_NaN)
        self.conv_match_1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_match_2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())
        #self.register_buffer('fuse_weight', fuse_weight)

        

    def forward(self, input):
        #get embedding
        embed_w = self.conv_assembly(input)
        match_input = self.conv_match_1(input)
        
        # b*c*h*w
        shape_input = list(embed_w.size())   # b*c*h*w
        input_groups = torch.split(match_input,1,dim=0)
        # kernel size on input for matching 
        kernel = self.scale*self.ksize   # e.g. 3*3 = 9 when scale=3 and ksize=3
        
        # raw_w is extracted for reconstruction 
        # extract overlapping patches with a sliding window of size kernel x kernel
        raw_w = extract_image_patches(embed_w, ksizes=[kernel, kernel],
                                      strides=[self.stride*self.scale,self.stride*self.scale],
                                      rates=[1, 1],
                                      padding='same') # [N, C*k*k, L]
        # raw_shape: [N, C, kernel, kernel, L], e.g. [N, 128, 9, 9, L]
        raw_w = raw_w.view(shape_input[0], shape_input[1], kernel, kernel, -1)
        raw_w = raw_w.permute(0, 4, 1, 2, 3)    # raw_shape: [N, L, C, k, k]
        raw_w_groups = torch.split(raw_w, 1, dim=0)
        
    
        # downscale X by self.scale (e.g. 3x) to form Y for cross-scale matching
        ref  = F.interpolate(input, scale_factor=1./self.scale, mode='bilinear')
        ref = self.conv_match_2(ref)
        w = extract_image_patches(ref, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')
        shape_ref = ref.shape
        # w shape: [N, C, k, k, L]
        w = w.view(shape_ref[0], shape_ref[1], self.ksize, self.ksize, -1)
        w = w.permute(0, 4, 1, 2, 3)    # w shape: [N, L, C, k, k]
        w_groups = torch.split(w, 1, dim=0)


        y = []
        scale = self.softmax_scale  
          # 1*1*k*k
        #fuse_weight = self.fuse_weight

        for xi, wi, raw_wi in zip(input_groups, w_groups, raw_w_groups):
            # normalize
            wi = wi[0]  # [L, C, k, k]
            max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
                                                     axis=[1, 2, 3],
                                                     keepdim=True)),
                               self.escape_NaN)
            wi_normed = wi/ max_wi

            # Compute correlation map
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)   # [1, L, H, W] L = shape_ref[2]*shape_ref[3]

            yi = yi.view(1, shape_ref[2]*shape_ref[3], shape_input[2], shape_input[3])  # [1, H_ref*W_ref, H, W]
            # rescale matching score
            yi = F.softmax(yi*scale, dim=1)
            if self.average == False:
                yi = (yi == yi.max(dim=1,keepdim=True)[0]).float()
            
            # deconv (transposed convolution) for reconstruction
            wi_center = raw_wi[0]           
            yi = F.conv_transpose2d(yi, wi_center, stride=self.stride*self.scale, padding=self.scale)
            
            yi = yi / 6.
            y.append(yi)
      
        y = torch.cat(y, dim=0)
        return y
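
Before walking through the helpers, here is a minimal usage sketch (it assumes the repository's `common` module and the helpers extract_image_patches / same_padding / reduce_sum described below are importable, and that H and W are divisible by `scale`). Tracing the shapes through forward(), the output should come out upscaled by `scale`:

import torch

# hypothetical shapes for illustration: N=2, C=128, H=W=48, scale=3
attn = CrossScaleAttention(channel=128, reduction=2, ksize=3, scale=3)
x = torch.randn(2, 128, 48, 48)
out = attn(x)
print(out.shape)   # expected: torch.Size([2, 128, 144, 144]), i.e. scale x the input resolution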

The helper functions and modules used above:

common.default_conv

 conv=common.default_conv
def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), stride=stride, bias=bias)

common.BasicBlock

self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())
class BasicBlock(nn.Sequential):
    def __init__(
        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
        bn=False, act=nn.PReLU()):

        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)

        super(BasicBlock, self).__init__(*m)

Assume the input has shape (N, C, H, W) and channel = 128.

embed_w = self.conv_assembly(input)
self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())

self.conv_assembly is equivalent to nn.Conv2d(128, 128, kernel_size=1, padding=0, stride=1) followed by nn.PReLU(); its output has shape (N, 128, H, W).

match_input = self.conv_match_1(input)
self.conv_match_1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())

With reduction = 2, self.conv_match_1 is equivalent to nn.Conv2d(128, 64, kernel_size=1, padding=0, stride=1) followed by nn.PReLU(); its output has shape (N, 64, H, W).
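
A quick shape check of these two 1x1 branches (a sketch using plain nn.Conv2d layers equivalent to the BasicBlock calls above):

import torch
import torch.nn as nn

x = torch.randn(2, 128, 48, 48)                                     # (N, C, H, W)
conv_assembly = nn.Sequential(nn.Conv2d(128, 128, 1), nn.PReLU())
conv_match_1 = nn.Sequential(nn.Conv2d(128, 64, 1), nn.PReLU())

print(conv_assembly(x).shape)   # torch.Size([2, 128, 48, 48])
print(conv_match_1(x).shape)    # torch.Size([2, 64, 48, 48])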

torch.split

input_groups = torch.split(match_input,1,dim=0)
torch.split(tensor, split_size_or_sections, dim=0)

torch.split() splits a tensor into chunks.

Parameters:

tensor: the input tensor to split.

split_size_or_sections (int or list): the size of a single chunk, or a list of chunk sizes.

dim (int): the dimension along which to split; dim=0 splits along rows, dim=1 along columns.

Output: the resulting chunks.
When split_size_or_sections is an int and the tensor's size along dim is divisible by it, all chunks have the same size; otherwise whatever remains simply forms a smaller final chunk.
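
A small example of the call used here, splitting along the batch dimension into chunks of size 1:

import torch

match_input = torch.randn(4, 64, 48, 48)            # (N, C, H, W)
input_groups = torch.split(match_input, 1, dim=0)   # tuple of N tensors

print(len(input_groups))        # 4
print(input_groups[0].shape)    # torch.Size([1, 64, 48, 48])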

extract_image_patches

        # raw_w is extracted for reconstruction 
        raw_w = extract_image_patches(embed_w, ksizes=[kernel, kernel],
                                      strides=[self.stride*self.scale,self.stride*self.scale],
                                      rates=[1, 1],
                                      padding='same') # [N, C*k*k, L]
        # raw_shape: [N, C, k, k, L]
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.
    :param padding:
    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
     each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :return: A Tensor
    # 'valid': every patch must be fully contained in the image; 'same': partial patches are allowed (the input is zero-padded first)
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    batch_size, channel, height, width = images.size()
    
    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding == 'valid':
        pass
    else:
        raise NotImplementedError('Unsupported padding type: {}. '
                                  'Only "same" or "valid" are supported.'.format(padding))

    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks
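
To make the reshape of raw_w in forward() concrete, here is a small sketch of the patch extraction with the running example (kernel = scale*ksize = 9, stride = stride*scale = 3, H = W = 48):

import torch

embed_w = torch.randn(2, 128, 48, 48)                # (N, C, H, W)
kernel, stride = 9, 3

raw_w = extract_image_patches(embed_w, ksizes=[kernel, kernel],
                              strides=[stride, stride], rates=[1, 1],
                              padding='same')        # [N, C*k*k, L] = [2, 10368, 256]
raw_w = raw_w.view(2, 128, kernel, kernel, -1)       # [N, C, k, k, L]
raw_w = raw_w.permute(0, 4, 1, 2, 3)                 # [N, L, C, k, k]
print(raw_w.shape)                                   # torch.Size([2, 256, 128, 9, 9])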

torch.nn.Unfold

It extracts sliding local blocks from a batched input tensor, i.e. it implements the sliding-window operation behind locally connected layers.

torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

Unfold takes an input of shape (N, C, H, W), where N is the batch size, C the number of channels, and H and W the spatial dimensions.
The output has shape (N, C×∏(kernel_size), L), where ∏ denotes the product (∏(kernel_size) is kernel height times kernel width) and L is the number of blocks obtained by sliding the kernel over the spatial dimensions.

For example, with an input of shape (1, 2, 4, 4), kernel_size=(2, 2) and stride=2, the output has shape (1, 8, 4):

import torch
inputs = torch.randn(1, 2, 4, 4)
print(inputs.size())
print(inputs)
unfold = torch.nn.Unfold(kernel_size=(2, 2), stride=2)
patches = unfold(inputs)
print(patches.size())
print(patches)

Looking at the printed result, nn.Unfold flattens every kernel_size[0]×kernel_size[1] sliding-window block of each input channel.

The function F.interpolate is used for upsampling or downsampling.

def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
    """
    Up- or down-samples the input according to the given size or scale_factor.

    Supports temporal, spatial and volumetric input, i.e. 3-D, 4-D and 5-D tensors of shape:
    mini-batch x channels x [optional depth] x [optional height] x width.

    Available algorithms: nearest, linear (3-D only), bilinear (4-D only), trilinear (5-D only).
    The default mode is 'nearest'.

    Args:
    - input (Tensor): input tensor
    - size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]): output spatial size.
    - scale_factor (float or Tuple[float]): multiplier for the spatial size.
    - mode (string): upsampling algorithm: nearest, linear, bilinear, trilinear, area. Default: nearest.
    - align_corners (bool, optional): if True, the corner pixels of the input and output tensors are
      aligned, preserving their values. Only has an effect when mode is linear, bilinear or trilinear.
      Default: False.
    """

For example, with an input of shape (1, 3, 64, 64), downsampling or upsampling:

    x = torch.randn([1, 3, 64, 64])
    y0 = F.interpolate(x, scale_factor=0.5)
    y1 = F.interpolate(x, size=[32, 32])

    y2 = F.interpolate(x, size=[128, 128], mode="bilinear")

    print(y0.shape)
    print(y1.shape)
    print(y2.shape)
# output:
#torch.Size([1, 3, 32, 32])
#torch.Size([1, 3, 32, 32])
#torch.Size([1, 3, 128, 128])

zip

The zip() function takes iterables as arguments, pairs up their corresponding elements into tuples, and returns an iterator over those tuples.

for xi, wi, raw_wi in zip(input_groups, w_groups, raw_w_groups):
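
A small illustration of how the three patch groups are iterated in lockstep (with placeholder values):

input_groups = ('x0', 'x1')
w_groups = ('w0', 'w1')
raw_w_groups = ('r0', 'r1')

for xi, wi, raw_wi in zip(input_groups, w_groups, raw_w_groups):
    print(xi, wi, raw_wi)
# x0 w0 r0
# x1 w1 r1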

reduce_sum

"reduce" here means reducing the dimensionality of a tensor, and the part after the underscore names the reduction operation; reduce_sum() therefore reduces a tensor by summation.

            max_wi = torch.max(torch.sqrt(reduce_sum(torch.pow(wi, 2),
                                                     axis=[1, 2, 3],
                                                     keepdim=True)),
                               self.escape_NaN)
def reduce_sum(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x

axis: the dimension(s) along which to sum.
keepdim: whether to keep the reduced dimensions; if False, each summed dimension is removed from the result.
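
In the attention module this computes the L2 norm of every matching kernel wi over its (C, k, k) dimensions. A quick sketch using the reduce_sum defined above:

import torch

wi = torch.randn(256, 64, 3, 3)                                    # [L, C, k, k]
norm = torch.sqrt(reduce_sum(torch.pow(wi, 2), axis=[1, 2, 3], keepdim=True))
print(norm.shape)                                                  # torch.Size([256, 1, 1, 1])

# the same result with a single torch.sum over several dims:
norm2 = torch.sqrt(torch.sum(wi ** 2, dim=(1, 2, 3), keepdim=True))
print(torch.allclose(norm, norm2))                                 # True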

The built-in sorted() function sorts any iterable.
sorted(iterable, key=None, reverse=False)
reverse: the sort order; reverse=True sorts in descending order, reverse=False (the default) in ascending order.

>>>sorted([5, 2, 3, 1, 4])
[1, 2, 3, 4, 5]                      # ascending by default

Difference between sort and sorted:
sort is a method on list objects, while sorted works on any iterable.
list.sort() sorts the existing list in place (and returns None), whereas the built-in sorted() returns a new list and leaves the original unchanged.
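
A short illustration of the in-place vs. new-list distinction:

a = [5, 2, 3, 1, 4]
b = sorted(a)        # returns a new list, a is unchanged
print(a, b)          # [5, 2, 3, 1, 4] [1, 2, 3, 4, 5]

a.sort()             # sorts in place and returns None
print(a)             # [1, 2, 3, 4, 5]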

pow(x, y) returns x raised to the power y (x^y); here torch.pow(wi, 2) squares every element of wi.

same_padding

            # Compute correlation map
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
def same_padding(images, ksizes, strides, rates):
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
    padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
    # Pad the input
    padding_top = int(padding_rows / 2.)
    padding_left = int(padding_cols / 2.)
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images
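
same_padding pads the input so that a sliding window with the given kernel size and stride covers every position, i.e. TensorFlow-style 'SAME' padding. A quick check with the settings used in forward() (ksize 3, stride 1, rate 1):

import torch

xi = torch.randn(1, 64, 48, 48)
xi_padded = same_padding(xi, [3, 3], [1, 1], [1, 1])
print(xi_padded.shape)   # torch.Size([1, 64, 50, 50]); a 3x3 conv with stride 1 then keeps 48x48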

F.conv2d

One way to think about it: nn.Conv2d is a 2D convolution layer (a module that owns its weights), while F.conv2d is the 2D convolution operation (a function that takes the weights as arguments).

import torch
from torch.nn import functional as F

"""手动定义卷积核(weight)和偏置"""
w = torch.rand(16, 3, 5, 5)  # 16种3通道的5乘5卷积核
b = torch.rand(16)  # 和卷积核种类数保持一致(不同通道共用一个bias)

"""定义输入样本"""
x = torch.randn(1, 3, 28, 28)  # 1张3通道的28乘28的图像

"""2D卷积得到输出"""
out = F.conv2d(x, w, b, stride=1, padding=1)  # 步长为1,外加1圈padding
print(out.shape)

out = F.conv2d(x, w, b, stride=2, padding=2)  # 步长为2,外加2圈padding
print(out.shape)
#运行结果
#torch.Size([1, 16, 26, 26])
#torch.Size([1, 16, 14, 14])

yi = F.conv2d(xi, wi_normed, stride=1)   # [1, L, H, W] L = shape_ref[2]*shape_ref[3]

wi = wi[0] # [L, C, k, k]
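
In other words, the L reference patches act as L convolution filters, so a single F.conv2d call computes the correlation of every image position with every reference patch. A shape sketch with the running example (H = W = 48, scale = 3, so L = 16*16 = 256 and C = 64 after reduction):

import torch
import torch.nn.functional as F

xi = torch.randn(1, 64, 50, 50)           # match_input after same_padding for a 3x3 kernel
wi_normed = torch.randn(256, 64, 3, 3)    # [L, C, ksize, ksize] normalized reference patches

yi = F.conv2d(xi, wi_normed, stride=1)
print(yi.shape)                           # torch.Size([1, 256, 48, 48]) = [1, L, H, W]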

F.conv_transpose2d

ConvTranspose2d implements the inverse operation of Conv2d, i.e. it upsamples an m×m image to n×n.

yi = F.conv_transpose2d(yi, wi_center, stride=self.stride*self.scale, padding=self.scale)

For example:
>>> # With square kernels and equal stride
>>> inputs = torch.randn(1, 4, 5, 5) #N,C,H,W
>>> weights = torch.randn(4, 8, 3, 3) # (in_channels, out_channels, kH, kW)
>>> F.conv_transpose2d(inputs, weights, padding=1)

torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                         stride=1, padding=0, output_padding=0,
                         groups=1, bias=True, dilation=1, padding_mode='zeros')

    in_channels (int): number of channels in the input tensor
    out_channels (int): number of channels in the output tensor
    kernel_size (int or tuple): size of the convolution kernel
    stride (int or tuple, optional): stride of the convolution; it determines the upsampling factor
    padding (int or tuple, optional): implicit zero-padding of the equivalent forward convolution; it shrinks the output by 2*padding
    output_padding (int or tuple, optional): additional size added to one side of the output shape
    groups: number of groups for grouped convolution (must divide both in_channels and out_channels)
    bias: whether to add a learnable bias
    dilation: spacing between kernel elements (i.e. dilated/atrous convolution)
    padding_mode (str): the padding type
    For parameters that accept a tuple, tuple[0] applies to the height dimension and tuple[1] to the width dimension.

Output size:
Height_out = (Height_in − 1) × stride − 2 × padding + dilation × (kernel_size − 1) + output_padding + 1
and analogously for the width.
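
Plugging in the values used in forward() (kernel = scale*ksize = 9, stride = stride*scale = 3, padding = scale = 3, dilation = 1, output_padding = 0) gives Height_out = (H − 1)·3 − 6 + 8 + 1 = 3H, so the attention output is upscaled by `scale`. A quick numeric check with the running example:

import torch
import torch.nn.functional as F

yi = torch.randn(1, 256, 48, 48)          # [1, L, H, W] matching scores
wi_center = torch.randn(256, 128, 9, 9)   # [L, C, kernel, kernel] patches from embed_w

out = F.conv_transpose2d(yi, wi_center, stride=3, padding=3)
print(out.shape)                          # torch.Size([1, 128, 144, 144]) -> 3x upscaled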

torch.cat

torch.cat concatenates (joins) tensors along a given dimension; "cat" is short for concatenate.

 y = torch.cat(y, dim=0)

dim=0  # concatenate along dimension 0 (stack vertically, row-wise)
dim=1  # concatenate along dimension 1 (stack horizontally, column-wise)

C = torch.cat((A, B), 0) concatenates A and B along dimension 0 (rows), i.e. vertically with A on top of B. Note that the number of columns must match (dimension 1 must be equal, e.g. both have 3 columns) so the columns align. The size of dimension 0 of the result is the sum of the inputs' dimension-0 sizes, e.g. 2 + 4 = 6.

C = torch.cat((A, B), 1) concatenates A and B along dimension 1 (columns), i.e. horizontally with A on the left and B on the right. Here the number of rows must match (dimension 0 must be equal, e.g. both have 2 rows) so the rows align. The size of dimension 1 of the result is the sum of the inputs' dimension-1 sizes, e.g. 3 + 4 = 7.

As these 2-D examples show, when using torch.cat((A, B), dim), all dimensions except the concatenation dimension dim must match.
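
The 2-D cases described above, as code:

import torch

A = torch.ones(2, 3)
B = torch.zeros(4, 3)
print(torch.cat((A, B), 0).shape)   # torch.Size([6, 3])

A = torch.ones(2, 3)
B = torch.zeros(2, 4)
print(torch.cat((A, B), 1).shape)   # torch.Size([2, 7])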
