Swin-Transformer详解

Swin-Transformer详解

  • 0. 前言
  • 1. Swin-Transformer结构简介
  • 2. Swin-Transformer结构详解
    • 2.1 Patch Partition
    • 2.2 Patch Merging
    • 2.3 Swin Transformer Block
      • 2.3.1 W-MSA
      • 2.3.2 SW-MSA
  • 3. 模型配置
  • 总结

0. 前言

Swin-Transformer是2021年微软研究院发表在ICCV上的一篇文章,并且已经获得ICCV 2021 best paper的荣誉称号。虽然Vision Transformer (ViT)在图像分类方面的结果令人鼓舞,但是由于其低分辨率特性映射和复杂度随图像大小的二次增长,其结构不适合作为密集视觉任务高分辨率输入图像的通过骨干网路。为了最佳的精度和速度的权衡,提出了Swin-Transformer结构。

论文名称Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
原论文地址: https://arxiv.org/abs/2103.14030
官方开源代码地址:https://github.com/microsoft/Swin-Transformer
Pytorch实现代码: pytorch_classification/swin_transformer
Tensorflow2实现代码:tensorflow_classification/swin_transformer

1. Swin-Transformer结构简介

如下图所示为:Swin-Transformer与ViT的对比结构。
Swin-Transformer详解_第1张图片
从上图中可以看出两种网络结构的部分区别:

  1. 采样方式
    • Swin-Transformer开始采用4倍下采样的方式,后续采用8倍下采样,最终采用16倍下采样
    • ViT则一开始就使用16倍下采样
  2. 目标检测机制
    • Swin-Transformer中,通过4倍、8倍、16倍下采样的结果分别作为目标检测所用数据,可以使网络以不同感受野训练目标检测任务,实现对大目标、小目标的检测
    • ViT则只使用16倍下采样,只有单一分辨率特征

接下来,简单看下原论文中给出的关于Swin Transformer(Swin-T)网络的架构图。其中,图(a)表示Swin Transformer的网络结构流程,图(b)表示两阶段的Swin Transformer Block结构。注意:在Swin Transformer中,每个阶段的Swin Transformer Block结构都是2的倍数,因为里面使用的都是两阶段的Swin Transformer Block结构。
Swin-Transformer详解_第2张图片

2. Swin-Transformer结构详解

首先,介绍Swin-Transformer的基础流程。

  1. 输入一张图片 [ H ∗ W ∗ 3 ] [H*W*3] [HW3]
  2. 图片经过Patch Partition层进行图片分割
  3. 分割后的数据经过Linear Embedding层进行特征映射
  4. 将特征映射后的数据输入具有改进的自关注计算的Transformer块(Swin Transformer块),并与Linear Embedding一起被称为第1阶段
  5. 与阶段1不同,阶段2-4在输入模型前需要进行Patch Merging进行下采样,产生分层表示。
  6. 最终将经过阶段4的数据经过输出模块(包括一个LayerNorm层、一个AdaptiveAvgPool1d层和一个全连接层)进行分类。

2.1 Patch Partition

Patch Partition结构是将图片数据进行分割成不重叠的M*M补丁。每个补丁被视为一个“标记”,其特征被设置为原始像素RGB值的串联。在论文中,使用4 × 4的patch大小,因此每个patch的特征维数为4 × 4 × 3 = 48。在此原始值特征上应用线性嵌入层(Linear Embedding),将其投影到任意维度(记为C)。

Swin-Transformer详解_第3张图片
图1 Patch Partition 分割
Swin-Transformer详解_第4张图片
图2 符号标识

注意:在实际操作中,Patch PartitionLinear Embedding通过一个二维的卷积层输出通道为Embedding维度卷积核大小为patch_sizestride大小为patch_size)实现。

2.2 Patch Merging

Patch Merging层主要是进行下采样,产生分层表示。随着网络的深入,通过Patch Merging层来减少令牌的数量。第一个补丁合并层将每组2 × 2相邻补丁的特征进行拼接,并在拼接后的4c维特征上应用线性层。这将令牌的数量减少2×2 = 4的倍数(分辨率的2倍降采样,长和宽分别变为原来的1/2),并将输出维度设置为2C。之后使用Swin Transformer块进行特征变换,分辨率保持在h8 × w8。这第一个块的补丁合并和特征转换被称为“第二阶段”。该过程重复两次,作为“阶段3”和“阶段4”,输出分辨率分别为h16 × w16h32 × w32。由上述的说明,可以得知:数据在经过Patch Merging层后,长宽变为原来的1/2,深度变为原来的2倍。

Swin-Transformer详解_第5张图片

2.3 Swin Transformer Block

Swin Transformer Block 一般以2阶段的串联结构出现,在第一阶段使用Window based Multi-headed Self-Attention(W-MSA),第二阶段使用 Shifted Window based Multi-headed Self-Attention(SW-MSA),根据当前是奇数还是偶数的Swin Transformer Block来选择不同的自关注计算方式。

2.3.1 W-MSA

W-MSA全称为:Window based Multi-headed Self-Attention。从名字可以看出,W-MSA是一个窗口化的多头自注意力,与全局自注意力相比,减少了大量的计算量。直观上来说:假如说是4*4的数据,划分后每个窗口包括 M ∗ M M*M MM 块,这里假设 M = 2 M=2 M=2。如果进行MSA计算大概需要 ( 4 ∗ 4 ) 2 (4*4)^2 442的计算量,而进行W-MSA则大概需要 ( 2 ∗ 2 ) ∗ ( 2 ∗ 2 ) 2 (2*2)*(2*2)^2 22222。这样一对比瞬间计算的复杂度就降低了很多(当然上述只是为了方便简单的理解,下面就详细介绍W-MSA降低了多少复杂度)。

Swin-Transformer详解_第6张图片
MSA (每个红框表示计算一次注意力)
Swin-Transformer详解_第7张图片
W-MSA (红框大小表示计算注意力像素大小)
对于一个 $h*w*C$ 的图像,被分割后每个窗口包括 $M*M$ 块。则对应的MSA和W-MSA的计算如下式所示: $$ Ω(MSA)=4hwC^2 +2(hw)^2C \quad\quad\quad\quad\quad\quad \quad\quad \ \ \ \ \ \ (1) \\\ Ω(W-MSA)=4hwC^2 +2M^2hwC \quad\quad\quad\quad\quad\quad\quad (2) $$
  • h代表feature map的高度
  • w代表feature map的宽度
  • C代表feature map的深度
  • M代表每个窗口(Windows)的大小

注意:前者与长宽 h w 成二次关系,后者在 M 固定时为线性关系(默认为7)。

  • 首先介绍下Self-Attention的计算
    Self-Attention的公式如下所示:
    A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T d ) V Attention(Q,K,V)=SoftMax(\frac{QK^T}{\sqrt{d} })V Attention(Q,K,V)=SoftMax(d QKT)V

  • 计算Self-Attention的复杂度
    首先,Q、K、V的计算如下所示:
    Q h w ∗ C = X h w ∗ C ∗ W Q C ∗ C   K h w ∗ C = X h w ∗ C ∗ W K C ∗ C   V h w ∗ C = X h w ∗ C ∗ W V C ∗ C Q^{hw*C}=X^{hw*C}*W_Q^{C*C} \\\ K^{hw*C}=X^{hw*C}*W_K^{C*C} \\\ V^{hw*C}=X^{hw*C}*W_V^{C*C} QhwC=XhwCWQCC KhwC=XhwCWKCC VhwC=XhwCWVCC

    • X h w ∗ C X^{hw*C} XhwC 表示将所有像素(token)拼接在一起得到的矩阵(一共有hw个像素,每个像素的深度为C)
    • W Q C ∗ C W_Q^{C*C} WQCC W K C ∗ C W_K^{C*C} WKCC W V C ∗ C W_V^{C*C} WVCC 分别表示生成Q、K、V的变换矩阵

    因此,由矩阵复杂度计算公式可知Q、K、V的复杂度均为 h w ∗ C 2 hw*C^2 hwC2,此时总复杂度为 3 h w ∗ C 2 3hw*C^2 3hwC2
    然后,由Self-Attention的计算公式可知, Q K T QK^T QKT 的计算量如下所示:
    Q h w ∗ C K T ( C ∗ h w ) = A h w ∗ h w Q^{hw*C}K^{T(C*hw)} = A^{hw*hw} QhwCKT(Chw)=Ahwhw
    因此, Q K T QK^T QKT 的计算量为 C ∗ h w ∗ h w C*hw*hw Chwhw, 即 C ∗ ( h w ) 2 C*(hw)^2 C(hw)2 。忽略 d \sqrt{d} d S o f t M a x SoftMax SoftMax操作, A ∗ V A*V AV的计算量如下所示:
    A h w ∗ h w V h w ∗ C = A t t e n t i o n h w ∗ C A^{hw*hw}V^{hw*C} = Attention^{hw*C} AhwhwVhwC=AttentionhwC
    因此, A ∗ V A*V AV 的计算量为 h w ∗ C ∗ h w hw*C*hw hwChw, 即 C ∗ ( h w ) 2 C*(hw)^2 C(hw)2 。所以,Self-Attention公式的复杂度为 2 C ( h w ) 2 2C(hw)^2 2C(hw)2。Self-Attention总的复杂度为 2 C ( h w ) 2 + 3 h w ∗ C 2 2C(hw)^2+3hw*C^2 2C(hw)2+3hwC2

  • 计算MSA的复杂度
    多头注意力计算复杂度与自注意力复杂度仅缺少一个 ∗ V 0 *V_0 V0 的操作,因此总体复杂度缺少 h w ∗ C 2 hw*C^2 hwC2。所以MSA的复杂度为 2 C ( h w ) 2 + 4 h w ∗ C 2 2C(hw)^2+4hw*C^2 2C(hw)2+4hwC2

  • 计算W-MSA的复杂度
    对于W-MSA模块首先要将feature map划分到一个个窗口(Windows)中,假设每个窗口的宽高都是M,那么总共会得到 h M × w M \frac {h} {M} \times \frac {w} {M} Mh×Mw个窗口,然后对每个窗口内使用多头注意力模块。刚刚计算高为h,宽为w,深度为C的feature map的计算量为 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C 4hwC2+2(hw)2C,这里每个窗口的高为M宽为M,带入公式得:
    4 ( M C ) 2 + 2 ( M ) 4 C 4(MC)^2 + 2(M)^4C 4(MC)2+2(M)4C
    又因为有 h M × w M \frac {h} {M} \times \frac {w} {M} Mh×Mw个窗口,则:
    h M × w M × ( 4 ( M C ) 2 + 2 ( M ) 4 C ) = 4 h w C 2 + 2 M 2 h w C \frac {h} {M} \times \frac {w} {M} \times (4(MC)^2 + 2(M)^4C)=4hwC^2 + 2M^2 hwC Mh×Mw×(4(MC)2+2(M)4C)=4hwC2+2M2hwC
    故使用W-MSA模块的计算量为:
    4 h w C 2 + 2 M 2 h w C 4hwC^2 + 2M^2 hwC 4hwC2+2M2hwC
    假设feature map的h、w都为112,M=7,C=128,采用W-MSA模块相比MSA模块能够节省约40124743680 FLOPs:
    2 ( h w ) 2 C − 2 M 2 h w C = 2 × 11 2 4 × 128 − 2 × 7 2 × 11 2 2 × 128 = 40124743680 2(hw)^2C-2M^2 hwC=2 \times 112^4 \times 128 - 2 \times 7^2 \times 112^2 \times 128=40124743680 2(hw)2C2M2hwC=2×1124×1282×72×1122×128=40124743680

2.3.2 SW-MSA

由于W-MSA只能关注窗口本身的内容,而不允许跨窗口连接,窗口与窗口之间是无法进行信息传递的。而SW-MSA通过移位窗口的方式,引入跨窗口连接的同时保持非重叠窗口的高效计算。如下图左所示为第 l 层使用W-MSA的方式,而在下一层 l+1 层必定为 SW-MSA的方式(如右图所示),两者合在一起作为一个2阶段的 Swin Transformer Block模块。两幅图进行对比可以发现:右图相对于左图进行了偏移,长宽分别偏移了 M 2 \frac{M}{2} 2M 个像素单位(每个窗口为 M ∗ M M*M MM 像素)。
Swin-Transformer详解_第8张图片
可以看出,偏移后的图像窗口变为了9个。为了提高计算的效率,作者提出了一种更有效的批处理计算方法,即向左上方向循环移位,如下图所示。在此转换之后,批处理窗口可能由特征映射中不相邻的几个子窗口组成,因此采用屏蔽机制(NLP中的masking 屏蔽不应该需要的信息)将自关注计算限制在每个子窗口内。
Swin-Transformer详解_第9张图片
为了更方便地理解左上方向循环移位的操作,这里将具体过程做了一个图,具体内容如下图所示。
Swin-Transformer详解_第10张图片
从上图可以看出,原始图像在进行移位后,A部分移动到右下角,B部分移位到最右边,C部分移位到最下边。然后将每个部分进行合并合并为等同于移位前窗口大小的窗口。
注意:移位后的信息会产生乱序,对于该问题,原文作者使用了Mask的方案。

3. 模型配置

最后,对Swin-Transformer各个版本的参数进行介绍。
Swin-Transformer详解_第11张图片
其中,

  • win. sz 7x7 表示窗口大小为7x7
  • dim表示feature map的channel深度(或者说token的向量长度)
  • head表示多头注意力模块中head的个数

总结

关于Swin-Transformer模型中大多数内容都已经详细介绍了。当然,还有部分不重要的内容以及如何与代码想匹配没有介绍。后续可能会出一篇文章专门介绍相关代码说明。如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。

你可能感兴趣的:(机器视觉,深度学习,transformer,深度学习,计算机视觉,人工智能)