SwinTransformer

解决Vit的计算复杂度问题:
传统的Vit:
假设图像切成4x4=16的patch,每个patch为16x16=2^8大小,则算self attention时,复杂度为 n 2 ∗ d = ( 2 4 ) 2 ∗ 2 8 = 2 16 n^2*d=(2^4)^2*2^8=2^{16} n2d=(24)228=216
SwinTransformer_第1张图片
本来Vit可以通过减少patch数量的方式来减少计算量,但ViT的patch内部是没有做attention的,patch越大,没有做信息交互的部分就越大。
而SwinTransformer做了几点改善:
(1)local self attention
如左图所示,灰色框是patch,红色框是local area,也就是说,一个patch不会和其他的所有patch做交互,仅仅和附近的patch做交互,这样可以把patch做的很小,同时计算量也不会大
(2)shift window local attention
SwinTransformer_第2张图片

如图所示,每个模块做两次local self attention,第一次如左图,这时候有个缺点,就是window 之间的patch并没有做交互,于是有了第二次的attention:shift window local attention
如layer 1+1所示,重新分配local window,至于如何实现得具体看代码了。

综上所述,swin transformer减小patch大小,使得没有信息交互的部分大大减少,同时通过local window self attention控制了计算量,同时还通过shift window使得window之间也做了信息交互,从而使得之前的transformer精度大大提高。另外作者通过改变patch的大小模拟多尺度操作。

补充:cyclic shift实现shift window attention的方法

SwinTransformer_第3张图片
layer1+1的这种window划分方式,导致每个window大小不一样不利于并行计算
于是要想办法把这些块通过拼接凑成一样的,方便并行计算,同时使用mask防止本不在一个window内的patch计算attention

(1)把上面的形状拼成4个同样大小的正方形

我们可以简化一下,假设每个数字代表一个patch:

1 2 2 4
5 6 6 8
5 6 6 8
13 14 14 16

1、把第一行roll到最后一行,即按行roll一次:

5 6 6 8
5 6 6 8
13 14 14 16
1 2 2 4

2、把第一列roll到最后一列,即按列roll一次:

6 6 8 5
6 6 8 5
14 14 16 13
2 2 4 1

因此,即完成了拼接,然后按四个大window内部计算attention即可

此时,左上角的4个6可以正常计算,但其他3个window都必须加mask

(2)mask attention

以右边这个正方形为例:

8 5
8 5

attention矩阵:

8 5 8 5
8
5
8
5

计算attention时加上mask:(0,1表示)

8 5 8 5
8 1 0 1 0
5 0 1 0 1
8 1 0 1 0
5 0 1 0 1

即可

再补充:torch.roll的用法

就是为了实现上面的roll操作
torch.roll(input_array, roll_num, roll_dim)
第一个参数:要roll的矩阵
第二个参数:roll多少次,一次正向roll一行/列,-1的话就反向roll一行/列
第三个参数:沿着哪一维roll,以二维矩阵为例,dim=0就是把每行当做一个单位按行roll,dim=1就是按列roll
roll的操作类似一个移位操作,按行正向移一次就相当于把矩阵整体向下移动一行移出去的位置填补到上面空出来的位置

补充:整体结构

SwinTransformer_第4张图片
跟resnet类似的层级结构,每个Stage之间加一个下采样模块
注意transformer的下采样跟传统图像下采样有一些区别:
先看看第一个输入:图像(H,W,3)分块,每个patch拉平后经过全连接变成一个新的word embedding
其实这里可以直接用卷积实现,也就是跟patch大小相同的卷积核和stride来卷积,直接就转成了embedding

(N, embedding),N:patch数量=w/patch_size * h/patch_size, embedding=patch_size^2
如果要下采样,首先要将N个patch embedding还原到原来的排列
(h, w, embedding),再用类似pixel shuffle的反操作下采样,然后降维SwinTransformer_第5张图片

你可能感兴趣的:(Transformer,python)