Swin Transformer总结
1.网络整体架构
下图是Swin Transformer文章中给出的图1,左边是本文要讲的Swin Transformer,右边边是之前讲的Vision Transformer。通过对比至少可以看出两点不同:
接下来,简单看下原论文中给出的关于Swin Transformer(Swin-T)网络的架构图。通过图(a)可以看出整个框架的基本流程如下:
最后对于分类网络,后面还会接上一个Layer Norm层、全局池化层以及全连接层得到最终输出。图中没有画,但源码中是这样做的。
2 Patch Merging详解
前面有说,在每个Stage中首先要通过一个Patch Merging层进行下采样(Stage1除外)。如下图所示,假设输入Patch Merging的是一个4x4大小的单通道特征图(feature map),Patch Merging会将每个2x2的相邻像素划分为一个patch,然后将每个patch中相同位置(同一颜色)像素给拼在一起就得到了4个feature map。接着将这四个feature map在深度方向进行concat拼接,然后在通过一个LayerNorm层。最后通过一个全连接层在feature map的深度方向做线性变化,将feature map的深度由C变成C/2。通过这个简单的例子可以看出,通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍。
这里的高和宽的乘积是patch的个数,56x56个4x4的patch,下采样2x之后维度增加四倍,liner projection只是维度增加2倍,变为28x28x2C,这里的patch数目由56x56变为28x28,patch数目减小,但是patch的尺寸增加由4x4变为8x8,这里的维度C和patch_size的大小是没有关系的,这样之后输出的窗口尺度发生了变化。
窗口尺寸增大↑,token数目减小↓,token数目减小之后,token矩阵的维度增加↑
3 W-MSA详解
引入Windows Multi-head Self-Attention(W-MSA)模块是为了减少计算量。如下图所示,左侧使用的是普通的Multi-head Self-Attention(MSA)模块,对于feature map中的每个像素(或称作token,patch)在Self-Attention计算过程中需要和所有的像素去计算。但在图右侧,在使用Windows Multi-head Self-Attention(W-MSA)模块时,首先将feature map按照MxM(例子中的M=2)大小划分成一个个Windows,然后单独对每个Windows内部进行Self-Attention。
原论文中有给出下面两个公式,这里忽略了Softmax的计算复杂度。:
Ω(MSA)=4hwC2+2(hw)2C (1)
Ω(W−MSA)=4hwC2+2M2hwC (2)
单头Self-Attention的公式:
MSA模块计算量
对于feature map中的每个像素(或称作token,patch),都要通过Wq , Wk , Wv生成对应的query(q) ,query(k) query(v),假设q,k,v的向量长度与feature map的深度C保持一致。那么对应的Q的生成过程如下公式:
Ahw×C⋅WqC×C=Qhw×C
Ahw×C 为所有token拼接在一起得到的矩阵(一共hw个token,每个token的维度为C)
WqC×C 为生成query的变换矩阵
Qhw×C 为所有token通过WqC×C得到的query拼接的矩阵
mAw0
根据矩阵运算的计算量公式可以得到生成Q的计算量为hwxCxC,生成的K和V同理都是hwC2 ,总计3 hwC2 。接下来Q 和K^T相乘,对应计算量为(hw)^2C
Qhw×C⋅KT(C×hw)= Λhw×hw
接下来忽略除以以及softmax的计算量,假设得到Λhw×hw,最后还要乘以V,对应的计算量为 (hw)^2C
那么对应单头的Self-Attention模块,总共需要3hwC^2 + (hw)^2C + (hw)^2C=3hwC^2 + 2(hw)^2C。而在实际使用过程中,使用的是多头的Multi-head Self-Attention模块,在之前的文章中有进行过实验对比,多头注意力模块相比单头注意力模块的计算量仅多了最后一个融合矩阵WO的计算量hwC^2
Bhw×C⋅WOC×C=MAW0 hw×C
所以总共加起来是:4hwC^2 + 2(hw)^2C。
W-MSA模块计算量
对于W-MSA模块首先要将feature map划分到一个个窗口(Windows)中,假设每个窗口的宽高都是M,那么总共会得到h/mxw/m
个窗口,然后对每个窗口内使用多头注意力模块。刚刚计算高为h,宽为w,深度为C的feature map的计算量为4hwC^2 + 2(hw)^2C
这里每个窗口的高为M宽为M,带入公式得:
4(MC)2+2(M)4C
又因为有h/mxw/m
个窗口,则
h/mxw/m×(4(MC)2+2(M)4C)=4hwC2+2M2hwC
故使用W-MSA模块的计算量为: 4hwC2+2M2hwC
4 SW-MSA详解
5 Relative Position Bias详解
Attention(Q,K,V)=SoftMax(QkTd+B)V
偏移从0开始,行、列标加上M-1,之后行标乘上2M-1,然后行标、列标相加
参考:Swin-Transformer网络结构详解_swin transformer-CSDN博客