swin-transformer学习笔记1——window_partition函数的理解

swin-transformer学习笔记1——window_partition函数的理解

功能如下所示
swin-transformer学习笔记1——window_partition函数的理解_第1张图片

原文关于这部分的代码如下

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

首先我们看一下输入x的几个属性,B是Batch,H是图片的高,W是图片的宽,C是通道数
我们知道在pytorch中view操作其实就是先把原来张量变成一维的,然后嵌入新的格式中。

x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)

这里首先B是不动的,也就是说,每张图片的元素是不变的,C也是不变的,也就是说每个像素点的通道值是不变的。
那么我们可以把这两个维度先忽略不看
我们发现,其实就是把一个维度为(H,W)的张量,分成了(H // window_size, window_size, W // window_size, window_size)维度的张量。
那么我们举个例子来分析这个过程

import torch
a=torch.arange(36).view(6,-1)
print (a)
print(a.view(3,-1))

输出

tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]])

也就是说第一步将图片水平切分了
swin-transformer学习笔记1——window_partition函数的理解_第2张图片
然后我们进一步切分:

import torch
a=torch.arange(36).view(6,-1)
print (a)
print(a.view(3,2,-1))

输出:

tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]])
tensor([[[ 0,  1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10, 11]],

        [[12, 13, 14, 15, 16, 17],
         [18, 19, 20, 21, 22, 23]],

        [[24, 25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34, 35]]])

可以看到这一步是把每一行都分了出来,并且在垂直方向上进行了分块。
之后:

import torch
a=torch.arange(36).view(6,-1)
print (a)
print(a.view(3,2,3,-1))

输出:

tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]])
tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],

         [[ 6,  7],
          [ 8,  9],
          [10, 11]]],


        [[[12, 13],
          [14, 15],
          [16, 17]],

         [[18, 19],
          [20, 21],
          [22, 23]]],


        [[[24, 25],
          [26, 27],
          [28, 29]],

         [[30, 31],
          [32, 33],
          [34, 35]]]])

这时,就是每一行都按照我们需要的window_size进行分块了,但是有一点过于细了。所以我们要把不同行的对应元素结合起来,于是我们需要用到permute操作。

x.permute(0,2,1,3)

这里等于是先将图片按照我们的要求切成(H//window_size,W//window_size)的形式,然后在细分。

你可能感兴趣的:(深度学习,pytorch,transformer)