CV方向通用方法是CNN,NLP领域存在Transformer方法。本文寻求扩展 Transformer的适用性,使其可以作为CV的backbone。
CV和NLP的差异:
文本提出了一个通用的Transformer主干,称为 Swin Transformer,它构建分层特征图并且对图像大小具有线性计算复杂度。
Swin Transformer 的一个关键设计元素是它在连续自注意力层之间的窗口分区的移动。移动的窗口桥接前一层的窗口,提供它们之间的连接。
先将图片分成不重叠的块,每个块成为token
,其特征是原始像素RGB值的串联,然后经过一个线性嵌入层投影到维度C。
在这些token上应用若干修正自注意力计算的Swin Transformer 块。Transformer 块保持令牌的数量,与线性嵌入一起被称为“阶段 1”。
随着网络变深,通过块合并层来减少token的数量。
这些阶段共同产生一个分层表示,具有与典型卷积网络相同的特征图分辨率。因此,所提出的架构可以方便地替换将骨干网络置于现有方法中,用于各种视觉任务。
Swin Transformer块
Swin Transformer 是通过将 Transformer 模块中的标准多头自注意力 (MSA,multi-head self attention) 模块替换为基于移动窗口的模块而构建的,其他层保持不变。由一个基于移动窗口的MSA 模块组成,后跟一个 2层MLP,其间具有GELU非线性。 在每个MSA模块和每个 MLP 之前应用一个正则化层,在每个模块之后应用一个残差连接
为了有效建模,提出在局部窗口内计算自注意力。窗口以不重叠的方式均匀地划分图像。 假设每个窗口包含 M × M M×M M×M个块,全局MSA 模块和基于 h × w 块图像的窗口的计算复杂度为:
Ω ( M S A ) = 4 h w C 2 + 2 ( h w ) 2 C Ω ( W − M S A ) = 4 h w C 2 + 2 M 2 h w C \Omega(MSA)=4hwC^2+2(hw)^2C\\ \Omega(W-MSA)=4hwC^2+2M^2hwC\\ Ω(MSA)=4hwC2+2(hw)2CΩ(W−MSA)=4hwC2+2M2hwC
(这里的计算复杂度就是qkv那一套理论)可以看到,全局自注意力复杂度正比于块数的平方,而非重叠窗口的自注意力对块数是线性的。
基于窗口的自注意力模块缺乏跨窗口的连接,限制了能力。 为了在保持非重叠窗口的高效计算的同时引入跨窗口连接,作者提出了一种移位窗口分区方法,在连续的 Swin Transformer 块中的两个分区配置之间交替。
第一个模块使用从左上角像素开始的常规窗口分区策略,将 8 × 8 8 × 8 8×8 特征图均匀地划分为大小为 4 × 4 ( M = 4 ) 4 × 4 (M = 4) 4×4(M=4) 的$ 2 × 2$ 窗口。 然后,下一个模块采用从前一层的窗口配置偏移的窗口配置,通过将窗口从规则分区的窗口中移动 ( ⌊ M 2 ⌋ , ⌊ M 2 ⌋ ) (\lfloor\frac{M}{2}\rfloor,\lfloor\frac{M}{2}\rfloor) (⌊2M⌋,⌊2M⌋)个像素。使用移位窗口分区方法,连续的 Swin Transformer 块计算为:
符号说明见上图。
移位窗口分区的一个问题是它会产生更多的窗口,从$\lceil \frac{h}{W} \rceil \times \lceil \frac{h}{W} \rceil $ 到 $(\lceil \frac{h}{W} \rceil+1) \times (\lceil \frac{h}{W} \rceil+1) $ ,有的窗口会比 M × M M×M M×M小。
作者提出了一种更有效的批量计算方法,通过向左上方向循环移位。 在这种移位之后,在特征图中一个批量窗口可能由几个不相邻的子窗口组成,因此采用屏蔽机制将自注意力计算限制在每个子窗口内。使用循环移位,批处理窗口的数量与常规窗口分区的数量相同,因此也是有效的。
B ∈ R M 2 × M 2 B\in R^{M^2\times M^2} B∈RM2×M2, M M M是一个窗口内的块的个数。由于沿每个轴的相对位置在 $[−M + 1, M −1] $范围内,参数化一个更小的偏置矩阵 B ^ ∈ R ( 2 M − 1 ) × ( 2 M − 1 ) \hat B \in R ^{(2M−1)×(2M−1)} B^∈R(2M−1)×(2M−1),并且 B B B的值取自 B ^ \hat B B^。
本文介绍了 Swin Transformer,一种新的视觉 Transformer,它产生分层特征表示并具有线性计算复杂度。Swin Transformer实现了先进的性能。 希望 Swin Transformer 在各种视觉问题上的强大表现将鼓励视觉和语言信号统一建模。作为 Swin Transformer 的一个关键元素,基于平移窗口的自注意力在视觉问题上被证明是有效和高效的,作者也期待研究其在自然语言处理中的应用。
注意,源代码里面使用的卷积层,把Patch Partition和Linear Embeeding作用合到了一起
window 划分:
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
移位窗口:
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
mask:
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
代码太长了不想看了,跑起来了就行