论文:https://arxiv.org/abs/2103.14030
代码:语义分割、图像分类、目标检测
相关阅读:更多
作者观察到Transformer在语言领域和视觉领域的 两种模式差异:
三 | 点 | 几 | 了 | , | 做 | 撚 | 啊 | 做 | |
---|---|---|---|---|---|---|---|---|---|
三 | |||||||||
点 | |||||||||
几 | |||||||||
了 | |||||||||
, | |||||||||
做 | |||||||||
撚 | |||||||||
啊 | |||||||||
做 |
为了克服这些问题,作者提出Swin Transformer 用来构建层次特征图,由于只在每个局部窗口(如红色方块)内计算 self-attention,因此具有线性计算复杂度的图像大小。如下图所示,Swin Transformer通过从较小的patches(底层的灰色方块)开始,并逐渐合并更深的Transformer层中的邻近patches(例如,底层4个小灰块合并为中间层一个较大的灰块),构造了一个层次表示。有了这些层次特征图,Swin Transformer模型可以方便地利用高级技术进行密集预测,如特征金字塔网络(FPN)或U-Net。
线性计算复杂度是通过在非重叠窗口内局部计算 self-attention 来实现的,该窗口对图像进行划分(用红色标出)。每个窗口中的patches数量是固定的,因此复杂度与图像大小成线性关系。这些优点使Swin Transformer适合作为各种视觉任务的通用 主干 ,与以前基于Transformer的架构形成单一分辨率和二次复杂度的特征图形成对比。
Swin Transformer 的一个关键设计元素是它在连续的self-attention层之间的窗口分区的移动,如下图所示。移动的窗口桥接了上一层的窗口,提供了它们之间的连接,极大地提高了建模能力。这个策略在现实世界的延迟方面也是有效的:一个窗口内的所有查询patches共享同一个键集,这有助于硬件上的内存访问。相比之下,早期的基于滑动窗口的self-attention方法在一般硬件上受到低延迟的影响,这是由于不同查询像素的不同键集所致。实验表明,所提出的移位窗口方法比滑动窗口方法具有更低的延迟。
Patch Partition: | Images分割成不重叠的patches,patch相当于语言Transformer的token,使用4×4的patch大小,因此每个patch的特征维数为4×4×3 = 48。 |
---|---|
Stage 1: | 每个patch的特征应用线性嵌入层(Linear Embedding),将其投射到任意维度(记为C)。 |
这些patches token上应用Swin Transformer Block。Transformer Block维持 token 的数量( H 4 × W 4 \frac{H}{4}×\frac{W}{4} 4H×4W)。 | |
Stage 2: | 第一个patches合并层将相邻4个patches的特征 拼接起来(减少 token 的数量),并在4c维的拼接特征上应用一个线性层。这将 patches 的数量减少了 4的倍(分辨率的2倍下采样),并且输出维数设置为2C。 |
然后应用 Swin Transformer Block 进行特征变换,分辨率保持在 H 8 × W 8 \frac{H}{8}×\frac{W}{8} 8H×8W。 | |
Stage 3: | 同 Stage 2,输出分辨率分别为 H 16 × W 16 \frac{H}{16}×\frac{W}{16} 16H×16W。 |
Stage 4: | 同 Stage 2,输出分辨率分别为 H 32 × W 32 \frac{H}{32}×\frac{W}{32} 32H×32W。 |
这些Stage共同产生一个层次表示(类似于CNN的层次结构),具有与典型卷积网络(如VGG和ResNet)相同的特征图分辨率。因此,该体系结构可以方便地替代现有方法中的 主干 网来完成各种视觉任务。
Swin Transformer的构建方法是将Transformer块中的标准Multi-head self-attention(MSA)模块替换为基于移动窗口的模块,其他层保持不变。如下图所示,Swin Transformer模块由基于MSA的平移窗口模块和介于GELU非线性之间的2层MLP组成。在每个MSA模块和每个MLP之前应用一个 LN层(层归一化),在每个模块之后应用一个残差连接。
连续Swin Transformer Block计算为
z ^ l = W-MSA ( LN ( z l − 1 ) ) + z l − 1 z l = MLP ( LN ( z l ) ) + z ^ l \begin{aligned} &\hat z^l= \text{W-MSA}(\text{LN}(z^{l−1}))+ z^{l−1}\\ &z^l= \text{MLP}(\text{LN}( z^l))+ \hat z^l\end{aligned} z^l=W-MSA(LN(zl−1))+zl−1zl=MLP(LN(zl))+z^l
z ^ l + 1 = SW-MSA ( LN ( z l ) ) + z l z l + 1 = MLP ( LN ( z l + 1 ) ) + z ^ l + 1 \begin{aligned} &\hat z^{l+1}= \text{SW-MSA}(\text{LN}(z^l))+ z^l\\ &z^{l+1}= \text{MLP}(\text{LN}(z^{l+1}))+ \hat z^{l+1} \end{aligned} z^l+1=SW-MSA(LN(zl))+zlzl+1=MLP(LN(zl+1))+z^l+1
其中, W-MSA \text{W-MSA} W-MSA为窗口MSA, SW-MSA \text{SW-MSA} SW-MSA为移动窗口MSA。前者解决规模问题,后者解决计算复杂度问题。
为解决全局MSA 二次复杂度 ,作者提出在 局部窗口中计算self-attention 。窗口被安排以不重叠的方式均匀地分割图像。假设每个窗口包含 M × M M × M M×M个patches,则一个全局MSA模块的计算复杂度为
Ω ( MSA ) = 4 h w C 2 + 2 ( h w ) 2 C Ω(\text{MSA}) = 4hwC^2+ 2(hw)^2C Ω(MSA)=4hwC2+2(hw)2C
基于 h × w h×w h×w 个patches图像的窗口计算复杂度为
Ω ( W-MSA ) = 4 h w C 2 + 2 M 2 h w C Ω(\text{W-MSA}) = 4hwC^2+ 2M^2hwC Ω(W-MSA)=4hwC2+2M2hwC
其中前者是 h w hw hw的二次,而后者是C的线性,当M固定时(默认设置为7)。全局MSA计算对于大型 h w hw hw来说通常是负担不起的,而基于W-MSA是可扩展的。
基于W-MSA 模块缺乏跨窗口的连接,这限制了它的建模能力。为了引入跨窗口连接,同时保持非重叠窗口的高效计算,我们提出了一种移动窗口分区方法 SW-MSA,该方法在连续Swin Transformer Block的两个分区配置之间交替使用。
如上图所示,W-MSA模块使用从左上角像素开始的常规窗口划分策略,8 × 8 特征图被均匀地分成4个4×4的窗口(M = 4)。然后,SW-MSA模块采用的窗口配置是从W-MSA模块传过来的,通过用( ∣ M 2 ∣ , ∣ M 2 ∣ |\frac{M}{2}|,|\frac{M}{2}| ∣2M∣,∣2M∣)像素替换常规划分的窗口。采用移动窗口划分方法,连续Swin Transformer Block计算。
W-MSA:
z ^ l = W-MSA ( LN ( z l − 1 ) ) + z l − 1 z l = MLP ( LN ( z l ) ) + z ^ l \begin{aligned} &\hat z^l= \text{W-MSA}(\text{LN}(z^{l−1}))+ z^{l−1}\\ &z^l= \text{MLP}(\text{LN}( z^l))+ \hat z^l\end{aligned} z^l=W-MSA(LN(zl−1))+zl−1zl=MLP(LN(zl))+z^l
SW-MSA:
z ^ l + 1 = SW-MSA ( LN ( z l ) ) + z l z l + 1 = MLP ( LN ( z l + 1 ) ) + z ^ l + 1 \begin{aligned} &\hat z^{l+1}= \text{SW-MSA}(\text{LN}(z^l))+ z^l\\ &z^{l+1}= \text{MLP}(\text{LN}(z^{l+1}))+ \hat z^{l+1} \end{aligned} z^l+1=SW-MSA(LN(zl))+zlzl+1=MLP(LN(zl+1))+z^l+1
其中,W-MSA和SW-MSA分别表示为使用规则MSA和移动窗口划分MSA的配置。
移窗划分的一个问题是,它将导致更多的窗口,从 [ h M ] × [ w M ] [\frac{h}{M}] \times [\frac{w}{M}] [Mh]×[Mw] 到 ( [ h M ] + 1 ) × ( [ w M ] + 1 ) ([\frac{h}{M}]+1) \times ([\frac{w}{M}]+1) ([Mh]+1)×([Mw]+1) 移动的配置,和一些窗口将小于M×M。一个简单的解决方案是将小窗口填充到M × M的大小,并在计算attention 时屏蔽填充值。当常规划分的窗口数量很小时,例如2 × 2,使用这种简单的解决方案增加的计算量是相当大的(2 × 2→3 × 3,是常规分区的2.25倍)。在这里,我们提出了一种更高效的批处理计算方法,即向左上角循环移动,如图4所示。在此移动之后,一个批处理窗口可能由几个在特征映射中不相邻的子窗口组成,因此采用掩蔽机制将 self-attention 计算限制在每个子窗口内。使用循环移动,批处理窗口的数量与常规划分窗口的数量相同,因此也是高效的。
计算self-attention,在计算相似度时,每个head包含一个相对位置偏差 B ∈ R M 2 × M 2 B∈\mathbb R^{M^2×M^2} B∈RM2×M2:
Attention ( Q , K , V ) = SoftMax ( Q K ⊤ d + B ) V \text{Attention}(Q, K, V ) = \text{SoftMax}(\frac{QK^{\top}}{\sqrt d} + B)V Attention(Q,K,V)=SoftMax(dQK⊤+B)V
其中 Q , K , V ∈ R M 2 × d Q, K, V∈\mathbb R^{M^2×d} Q,K,V∈RM2×d分别为查询、键和值矩阵;d为查询/键维, m 2 m^2 m2为窗口中的patches数。由于每个轴上的相对位置在 [ − M + 1 , M − 1 ] [−M + 1, M−1] [−M+1,M−1]范围内,因此我们参数化一个较小的偏置矩阵 B ^ ∈ R ( 2 M − 1 ) × ( 2 M − 1 ) \hat B∈\mathbb R^{(2M−1)×(2M−1)} B^∈R(2M−1)×(2M−1), B B B中的值取 B ^ \hat B B^。
我们观察到相对于没有这个偏置项或使用绝对位置嵌入的对等项的显著改进,如表4所示。在输入中进一步添加绝对位置嵌入(如在[19]中)会略微降低性能,因此在我们的实现中没有采用它。
在训练前学习到的相对位置偏差也可以通过双三次插值初始化模型,用于不同窗口大小的微调。
基本模型Swin-B,其模型规模和计算复杂度类似于ViTB/DeiT-B。还介绍了Swin-T、Swin-S和Swin-L,它们分别是模型尺寸和计算复杂度分别为0.25×、0.5×和2×的版本。需要注意的是,Swin-T和Swin-S的复杂度分别与ResNet-50 (DeiT-S)和ResNet-101相似。窗口大小默认设置为M = 7。对于所有实验,每个头部的查询维度为d = 32,每个MLP的扩展层为α = 4。这些模型变体的架构超参数是:
其中C为第一阶段隐藏层的信道数。ImageNet图像分类的模型大小、理论计算复杂度(FLOPs)和模型变体的吞吐量列在下表中。