本文整体设计思路与之前的Vision Transformer Backbone一致,如何获取多尺度特征(PVT)以及如何降低Self-Attention的计算复杂度(SwinTransformer)。
第一点PVT的方法就是一种比较合适的方法,各个Stage将Embedding进行合并,从而减少Embedding的数目,以实现不同Stage具有不同分辨率的金字塔结构。
第二点SwinTransformer通过在Local Window内部计算Self-Attention的方式降低了计算复杂度,同时提出了一种比较复杂的Shift-Window的方式去捕获各个窗口之间的依赖关系。
作者认为SwinTransformer的Shift-Window方法较为复杂,且现代深度学习框架支持性不够好,因此,本文提出了一种更为简单的方法去实现。
整体思路可以认为是PVT+SwinTransformer的结合:在局部窗口内部计算Self-Attention(SwinTransformer),同时对每个窗口内部的特征进行压缩,然后再使用一个全局Attention机制去捕获各个窗口的关系(PVT)。
个人觉得Twins这种思路更加简单,而且也比较高效,SwinTransformer的实现包含了作者大量精巧的设计(阅读过SwinTransformer的代码应该都会觉得其实现是十分巧妙且美妙的),实际上也是比较复杂的。大道至简,还是Twins的设计更符合认知一点,不过其性能还有待验证。
相较于CNN来说,Transformer由于其能高效地捕获远距离依赖的特性,近期在计算机视觉领域也引领了一波潮流。Transformer主要是依靠Self-Attention去捕获各个token之间的关系,但是这种Global Self-Attention的计算复杂度太高 O ( N 2 ) O(N^2) O(N2),不利于在token数目较多的密集检测任务(分割、检测)中使用。
基于以上考虑,目前主流有两种应对方法:
总结下来,目前的VisionTransformer的关键设计在于Spatial Attention的设计。
因此,本文将重点在Spatial Attention的设计上,期望提出一个高效同时简单的Spatial Attention方法。
作者有两个发现:
PVT通过逐步融合各个Patch的方式,形成了一种多尺度的结构,使得其更适合用于密集预测任务例如目标检测或者是语义分割,其继承了ViT和DeiT的Learnable Positional Encoding的设计,所有的Layer均直接使用Global Attention机制,并通过Spatial Reduction的方式去降低计算复杂度。
作者通过实验发现,PVT与SwinTransformer的性能差异主要来自于PVT没有采用一个合适的Positional Encoding方式,通过采用Conditional Positional Encoding(CPE)去替换PVT中的PE,PVT即可获得与当前最好的SwinTransformer相近的性能。关于CPE的具体介绍,可以参见我的另一篇博客:Conditional Positional Encodings for Vision Transformers。
通过提出的Spatially Separable Self-Attention(SSSA)去缓解Self-Attention的计算复杂度过高的问题。SSSA由两个部分组成:Locally-Grouped Self-Attention(LSA)和Global Sub-Sampled Attention(GSA)。
首先将2D feature map划分为多个Sub-Windows,并仅在Window内部进行Self-Attention计算,计算量会大大减少,由 ( H 2 W 2 d ) \left(H^{2} W^{2}d\right) (H2W2d)下降至 O ( k 1 k 2 H W d ) \mathcal{O}\left(k_{1} k_{2} H W d\right) O(k1k2HWd),其中 k 1 = H m , k 2 = W n k_{1}=\frac{H}{m}, k_{2}=\frac{W}{n} k1=mH,k2=nW,当 k 1 , k 2 k_1,k_2 k1,k2固定时,计算复杂度将仅与 H W HW HW呈线性关系。
LSA缺乏各个Window之间的信息交互,比较简单的一个方法是,在LSA后面再接一个Global Self-Attention Layer,这种方法在实验中被证明也是有效的,但是其计算复杂度会较高: O ( H 2 W 2 d ) \mathcal{O}\left(H^{2} W^{2} d\right) O(H2W2d)。
另一个思路是,将每个Window提取一个维度较低的特征作为各个window的表征,然后基于这个表征再去与各个window进行交互,相当于Self-Attention中的Key的作用,这样一来,计算复杂度会下降至: O ( m n H W d ) = O ( H 2 W 2 d k 1 k 2 ) \mathcal{O}(m n H W d)=\mathcal{O}\left(\frac{H^{2} W^{2} d}{k_{1} k_{2}}\right) O(mnHWd)=O(k1k2H2W2d)。
这种方法实际上相当于对feature map进行下采样,因此,被命名为Global Sub-Sampled Attention。
综合使用LSA和GSA,可以取得类似于Separable Convolution(Depth-wise+Point-wise)的效果,整体的计算复杂度为: O ( H 2 W 2 d k 1 k 2 + k 1 k 2 H W d ) \mathcal{O}\left(\frac{H^{2} W^{2} d}{k_{1} k_{2}}+k_{1} k_{2} H W d\right) O(k1k2H2W2d+k1k2HWd)。同时有: H 2 W 2 d k 1 k 2 + k 1 k 2 H W d ≥ 2 H W d H W \frac{H^{2} W^{2} d}{k_{1} k_{2}}+k_{1} k_{2} H W d \geq 2 H W d \sqrt{H W} k1k2H2W2d+k1k2HWd≥2HWdHW,当且仅当 k 1 ⋅ k 2 = H W k_{1} \cdot k_{2}=\sqrt{H W} k1⋅k2=HW。
考虑到分类任务中, H = W = 224 H=W=224 H=W=224是比较常规的设置,同时,不是一般性使用方形框,则有 k 1 = k 2 k_1=k_2 k1=k2,第一个stage的feature map大小为56,可得 k 1 = k 2 = 56 = 7 k_1=k_2=\sqrt{56}=7 k1=k2=56=7。
当然可以针对各个Stage去设定其窗口大小,不过为了简单性,所有的 k k k均设置为7。
整个Transformer Block可以被表示为:
z ^ i j l = LSA ( LayerNorm ( z i j l − 1 ) ) + z i j l − 1 z i j l = FFN ( LayerNorm ( z ^ i j l ) ) + z ^ i j l z ^ l + 1 = GSA ( LayerNorm ( z l ) ) + z l z l + 1 = FFN ( LayerNorm ( z ^ l + 1 ) ) + z ^ l + 1 , i ∈ { 1 , 2 , … , m } , j ∈ { 1 , 2 , … , n } \begin{array}{l} \hat{\mathbf{z}}_{i j}^{l}=\text { LSA }\left(\text { LayerNorm }\left(\mathbf{z}_{i j}^{l-1}\right)\right)+\mathbf{z}_{i j}^{l-1} \\ \mathbf{z}_{i j}^{l}=\text { FFN }\left(\text { LayerNorm }\left(\hat{\mathbf{z}}_{i j}^{l}\right)\right)+\hat{\mathbf{z}}_{i j}^{l} \\ \hat{\mathbf{z}}^{l+1}=\text { GSA }\left(\text { LayerNorm }\left(\mathbf{z}^{l}\right)\right)+\mathbf{z}^{l} \\ \mathbf{z}^{l+1}=\text { FFN }\left(\text { LayerNorm }\left(\hat{\mathbf{z}}^{l+1}\right)\right)+\hat{\mathbf{z}}^{l+1}, \\ i \in\{1,2, \ldots, m\}, j \in\{1,2, \ldots, n\} \end{array} z^ijl= LSA ( LayerNorm (zijl−1))+zijl−1zijl= FFN ( LayerNorm (z^ijl))+z^ijlz^l+1= GSA ( LayerNorm (zl))+zlzl+1= FFN ( LayerNorm (z^l+1))+z^l+1,i∈{1,2,…,m},j∈{1,2,…,n}同时在每个Stage的第一个Block中会引入CPVT中的的PEG对位置信息进行编码。
通过以上实验结果可以看出,Twins系列在各个任务上均取得了与SwinTransformer相当甚至是超过的水平,不过相比较而言,除了Small,Twins的模型参数比SwinTransformer系列都稍微大一点,而且运行速度似乎也没有明显优势。
本文提出了两种Vision Transformer Backbone,同时适用于图片级的分类任务或是其他密集预测任务,并且在分类、分割、检测等多个任务上,均取得了新的SOTA。