Left: the overall hierarchical architecture of our proposed CSWin Transformer, Right: the illustration of our proposed CSWin Transformer block.
problem: tokens within a Transformer block have a limited attention area, so more blocks must be stacked to achieve a global receptive field
existing solutions: apply halo (HaloNet) or shifted windows (Swin) to enlarge the receptive field
a more efficient way: cross-shaped window self-attention, computed in horizontal and vertical stripes in parallel
(a) Left: the illustration of the Cross-Shaped Window (CSWin) with stripe width sw for the query point (red dot). Right: the computing of CSWin self-attention, where the multi-heads ($\{h_1, ..., h_K\}$) are first split into two groups, then the two groups of heads perform self-attention in horizontal and vertical stripes respectively, and finally are concatenated together. (b), (c), (d), and (e) are existing self-attention mechanisms.
given input feature $x\in R^{h\times w\times C}$, linearly projected to K heads, which are equally split into 2 parallel groups
each head in the 2 groups performs local self-attention within either horizontal or vertical stripes
x is evenly partitioned into horizontal stripes, each with $sw\times w$ tokens
$$x=[x_1, x_2, ..., x_m]$$
where $x_i\in R^{(sw\times w)\times C}$, $m=\frac h{sw}$
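The partition above can be sketched with a plain reshape. This is a minimal NumPy illustration rather than the authors' code: the function name `horizontal_stripes` is hypothetical, and the sketch assumes h is divisible by sw.

```python
import numpy as np

def horizontal_stripes(x, sw):
    """Split an (h, w, C) feature map into m = h // sw horizontal stripes.

    Returns an array of shape (m, sw * w, C): stripe i holds the sw * w
    tokens of rows [i * sw, (i + 1) * sw).
    """
    h, w, C = x.shape
    assert h % sw == 0, "this sketch assumes h is divisible by sw"
    # Rows are contiguous in memory, so a reshape groups sw consecutive rows.
    return x.reshape(h // sw, sw * w, C)

x = np.arange(8 * 4 * 2, dtype=float).reshape(8, 4, 2)  # h=8, w=4, C=2
stripes = horizontal_stripes(x, sw=2)
print(stripes.shape)  # (4, 8, 2)
```

Each of the m stripes can then be treated as an independent sequence of tokens, which is what lets attention be computed per stripe.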
calculate self-attention within each stripe for the k-th head
$$y_k^i=\text{WMSA}(x_iW_k^Q, x_iW_k^K, x_iW_k^V), \quad i=1, ..., m$$
$$\text{HCS-WMSA}_k(x)=[y_k^1, y_k^2, ..., y_k^m]$$
similarly, for vertical stripes, the attention of the k-th head is denoted as $\text{VCS-WMSA}_k(x)$
concatenate the horizontal and vertical attention outputs together
$$\text{CS-WMSA}(x)=\text{concat}(head_1, ..., head_K)W$$
where $W\in R^{C\times C}$ is the projection matrix that projects the self-attention results into the target output dimension, and
$$head_k=\begin{cases}\text{HCS-WMSA}_k(x), & k=1, ..., \frac K2\\ \text{VCS-WMSA}_k(x), & k=\frac K2+1, ..., K\end{cases}$$
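The head split and concatenation can be sketched end to end in NumPy. This is an illustrative sketch under stated assumptions, not the reference implementation: the names `stripe_attention` and `cs_wmsa` are hypothetical, the projections are plain C x C matrices sliced per head, positional encoding is left out, and h and w are assumed divisible by sw.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def stripe_attention(q, k, v, sw, horizontal):
    """Self-attention of one head, restricted to stripes of width sw.

    q, k, v have shape (h, w, d). Vertical stripes are handled by swapping
    the two spatial axes, which turns them into horizontal stripes.
    """
    h, w, d = q.shape
    if not horizontal:
        q, k, v = (t.transpose(1, 0, 2) for t in (q, k, v))
        h, w = w, h
    assert h % sw == 0, "this sketch assumes the stripe width divides the size"
    m = h // sw
    q, k, v = (t.reshape(m, sw * w, d) for t in (q, k, v))
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d))  # (m, sw*w, sw*w)
    out = (attn @ v).reshape(h, w, d)
    return out if horizontal else out.transpose(1, 0, 2)

def cs_wmsa(x, Wq, Wk, Wv, Wo, num_heads, sw):
    """First half of the heads use horizontal stripes, second half vertical."""
    C = x.shape[-1]
    d = C // num_heads
    q, k, v = x @ Wq, x @ Wk, x @ Wv  # each (h, w, C)
    heads = []
    for i in range(num_heads):
        sl = slice(i * d, (i + 1) * d)
        horizontal = i < num_heads // 2
        heads.append(stripe_attention(q[..., sl], k[..., sl], v[..., sl],
                                      sw, horizontal))
    return np.concatenate(heads, axis=-1) @ Wo

rng = np.random.default_rng(0)
h, w, C, K = 4, 4, 8, 2
x = rng.standard_normal((h, w, C))
Wq, Wk, Wv, Wo = (0.1 * rng.standard_normal((C, C)) for _ in range(4))
y = cs_wmsa(x, Wq, Wk, Wv, Wo, num_heads=K, sw=1)
print(y.shape)  # (4, 4, 8)
```

With sw=1, the output at position (0, 0) depends only on row 0 and column 0, which is the cross shape the method is named after.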
sw is adjusted per stage: small sw for early stages, larger sw for later stages
for high-resolution inputs, h and w are larger than C in early stages and smaller than C in later stages
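One way to see why sw grows with depth is to factor the CS-WMSA cost $4hwC^2+sw(h+w)hwC$ that the complexity analysis later in these notes arrives at. The feature-map sizes used below are hypothetical values chosen only for illustration, not configurations taken from the paper.

$$\Omega(\text{CS-WMSA})=hwC\,\bigl(4C+sw\,(h+w)\bigr)$$

With $h=w=200$ and $C=64$ (an early, high-resolution stage), $4C=256$ while $sw(h+w)$ is 400 for $sw=1$ and 2800 for $sw=7$, so the stripe term dominates and a small sw keeps the cost down. With $h=w=25$ and $C=512$ (a late stage), $4C=2048$ while $sw(h+w)=350$ even for $sw=7$, so a wide stripe adds comparatively little.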
with encoder blocks containing CS-WMSA, the Transformer block is computed as
$$\begin{aligned} \widehat{z}_l&=\text{CS-WMSA}(LN(z_{l-1}))+z_{l-1} \\ z_l&=FFN(LN(\widehat{z}_l))+\widehat{z}_l \end{aligned}$$
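The two residual updates above can be sketched as a small function. This is a hedged illustration: `cswin_block` is a hypothetical name, the attention and FFN are passed in as arbitrary callables, and the layer norm omits its learnable scale and shift for brevity.

```python
import numpy as np

def layer_norm(z, eps=1e-5):
    """Normalize over the channel axis (learnable scale and shift omitted)."""
    mean = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mean) / np.sqrt(var + eps)

def cswin_block(z_prev, attn, ffn):
    """Pre-norm residual block.

    attn and ffn are callables mapping an (h, w, C) array to an (h, w, C)
    array; attn stands in for CS-WMSA and ffn for the feed-forward network.
    """
    z_hat = attn(layer_norm(z_prev)) + z_prev   # attention sub-layer
    return ffn(layer_norm(z_hat)) + z_hat       # feed-forward sub-layer

rng = np.random.default_rng(0)
z0 = rng.standard_normal((4, 4, 8))
# Stand-in sub-layers, used only to exercise the residual wiring.
z1 = cswin_block(z0, attn=lambda t: 0.5 * t, ffn=lambda t: np.tanh(t))
print(z1.shape)  # (4, 4, 8)
```

Because both sub-layers are wrapped in residual connections, a block whose branches output zeros passes its input through unchanged.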
Comparison among different positional encoding mechanisms: APE and CPE introduce the positional information before feeding into the Transformer blocks, while RPE and our LePE operate in each Transformer block. Different from RPE that adds the positional information into the attention calculation, our LePE operates directly upon V and acts as a parallel module. Here we only draw the self-attention part to represent the Transformer block for simplicity.
APE/CPE add positional information before the Transformer blocks
RPE adds positional information within the attention calculation
LePE imposes positional information upon the linearly projected values
$$\text{Attention}(Q, K, V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V+EV$$
considering all connections in E would require a huge computation cost; assuming the most important positional information comes from the local neighborhood of each input element, this becomes
$$\text{Attention}(Q, K, V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V+\text{DWConv}(V)$$
where LePE is implemented by a depth-wise convolution: a 3x3 group conv with groups=embed_dim
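The LePE term can be sketched as a depth-wise 3x3 convolution applied to V and added to the attention output. This is a minimal NumPy sketch under assumptions: the names `depthwise_conv3x3` and `attention_with_lepe` are hypothetical, the convolution uses zero padding of 1, and attention is computed over a single (h, w) window for one head.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def depthwise_conv3x3(v, kernel):
    """Depth-wise 3x3 convolution (cross-correlation, as in deep learning
    frameworks) with zero padding.

    v: (h, w, d) values; kernel: (3, 3, d), one 3x3 filter per channel.
    """
    h, w, _ = v.shape
    padded = np.pad(v, ((1, 1), (1, 1), (0, 0)))
    out = np.zeros(v.shape, dtype=float)
    for dy in range(3):
        for dx in range(3):
            # Each channel is filtered independently (groups == channels).
            out += padded[dy:dy + h, dx:dx + w, :] * kernel[dy, dx, :]
    return out

def attention_with_lepe(q, k, v, kernel):
    """softmax(Q K^T / sqrt(d)) V + DWConv(V) over one (h, w) window."""
    h, w, d = v.shape
    qf, kf, vf = (t.reshape(h * w, d) for t in (q, k, v))
    attn = softmax(qf @ kf.T / np.sqrt(d))
    return (attn @ vf).reshape(h, w, d) + depthwise_conv3x3(v, kernel)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((4, 4, 3)) for _ in range(3))
kernel = 0.1 * rng.standard_normal((3, 3, 3))
out = attention_with_lepe(q, k, v, kernel)
print(out.shape)  # (4, 4, 3)
```

Since the positional term only reads a 3x3 neighborhood of V, it does not depend on the window size, which is consistent with the note that LePE copes with varying input resolution.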
in ViT, given input feature map $x\in R^{h\times w\times C}$, the FLOPs of MSA are
$$\Omega(MSA)=4hwC^2+2(hw)^2C$$
for the horizontal stripes, which use $\frac 12 n_{heads}$ heads, replace h with sw (halving the heads halves the cost)
$$\Omega(h)=2w(sw)C^2+(w\times sw)^2C$$
for the vertical stripes, which use the other $\frac 12 n_{heads}$ heads, replace w with sw
$$\Omega(w)=2h(sw)C^2+(h\times sw)^2C$$
for CS-WMSA, the stripes are folded into the batch dimension (batchsize $\times n_{stripes}$), with $n_{stripes}=\frac h{sw}$ for horizontal stripes or $\frac w{sw}$ for vertical stripes
$$\Omega(\text{CS-WMSA})=\Omega(h)\times \frac h{sw}+\Omega(w)\times \frac w{sw}=4hwC^2+sw(h+w)hwC$$
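The closed form $4hwC^2+sw(h+w)hwC$ can be checked numerically by summing the per-stripe costs. The sketch below uses hypothetical function names and assumes h and w are divisible by sw; each group holds half of the heads, which is why the per-stripe cost is half of a full MSA over the stripe's tokens.

```python
def flops_msa(h, w, C):
    """Full multi-head self-attention over all h * w tokens."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C

def flops_half_heads(n_tokens, C):
    """Attention over n_tokens tokens using half of the heads."""
    return 2 * n_tokens * C**2 + n_tokens**2 * C

def flops_cs_wmsa(h, w, C, sw):
    """Sum over h / sw horizontal stripes and w / sw vertical stripes."""
    assert h % sw == 0 and w % sw == 0, "sketch assumes sw divides h and w"
    horizontal = flops_half_heads(sw * w, C) * (h // sw)
    vertical = flops_half_heads(sw * h, C) * (w // sw)
    return horizontal + vertical

def flops_cs_wmsa_closed(h, w, C, sw):
    """Closed form of the same quantity."""
    return 4 * h * w * C**2 + sw * (h + w) * h * w * C

# Illustrative sizes only; they are not taken from the paper's tables.
for sw in (1, 2, 7):
    print(sw, flops_cs_wmsa(56, 56, 64, sw), flops_msa(56, 56, 64))
```

The attention term grows linearly with sw, so the stripe width trades receptive field against cost.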
Detailed configurations of different variants of CSWin Transformer. Note that the FLOPs are calculated with 224x224 input.
dataset: ImageNet-1K, with the same augmentation as DeiT
optimizer: AdamW, batchsize=1024, 300 epochs, init lr=1e-3, weight decay=0.05 or 0.1, linear warm-up for 20 epochs, cosine decay
stochastic depth: 0.1, 0.3, 0.5 for CSWin-T, CSWin-S, CSWin-B
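The learning-rate schedule in the optimizer settings above (linear warm-up followed by cosine decay) can be sketched as follows. The per-epoch granularity, the warm-up starting at `base_lr / warmup_epochs`, and the final value `min_lr=0.0` are assumptions of this sketch; the notes do not specify them, and `lr_at_epoch` is a hypothetical name.

```python
import math

def lr_at_epoch(epoch, base_lr=1e-3, warmup_epochs=20, total_epochs=300,
                min_lr=0.0):
    """Learning rate for a 0-indexed epoch."""
    if epoch < warmup_epochs:
        # Linear warm-up: reaches base_lr on the last warm-up epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine decay from base_lr down to min_lr over the remaining epochs.
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

for epoch in (0, 19, 20, 160, 299):
    print(epoch, lr_at_epoch(epoch))
```

In practice such schedules are often stepped per iteration rather than per epoch; the shape of the curve is the same.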
Comparison of different models on ImageNet-1K classification. “*” means the EfficientNet are trained with other input sizes. Here the models are grouped based on the computation complexity.
ImageNet-1K fine-tuning results by pre-training on ImageNet-21K datasets.
framework: Mask R-CNN
dataset: COCO
Object detection and instance segmentation performance on the COCO val2017 with the Mask R-CNN framework. The FLOPs (G) are measured at resolution 800x1280, and the models are pre-trained on the ImageNet-1K dataset.
framework: Cascade Mask R-CNN
dataset: COCO
optimizer: AdamW, batchsize=16, 36 epochs, init lr=1e-4, weight decay=0.05, lr decayed by a factor of 0.1 at the 27th and 33rd epochs
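The step decay in the optimizer settings above can be written as a small function. This sketch assumes the decay takes effect from the milestone epoch onward and that epochs are counted as in the notes; `step_lr` is a hypothetical name.

```python
def step_lr(epoch, base_lr=1e-4, milestones=(27, 33), gamma=0.1):
    """Multiply the learning rate by gamma once for each milestone reached."""
    reached = sum(epoch >= m for m in milestones)
    return base_lr * gamma**reached

for epoch in (0, 26, 27, 32, 33, 35):
    print(epoch, step_lr(epoch))
```

Over the 36 epochs this gives three plateaus: the initial rate, one tenth of it, and one hundredth of it.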
Object detection and instance segmentation performance on the COCO val2017 with Cascade Mask R-CNN.
framework: Semantic FPN
dataset: ADE20K
optimizer: AdamW, batchsize=16, 80K iterations, init lr=1e-4, weight decay=1e-4
framework: UPerNet
dataset: ADE20K
optimizer: AdamW, batchsize=16, 160K iterations, init lr=6e-5, weight decay=5e-4, linear warm-up for 1500 iterations, linear decay
stochastic depth: 0.1, 0.3, 0.5 for CSWin-T, CSWin-S, CSWin-B
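For readers unfamiliar with the stochastic depth setting, the sketch below shows one common per-sample formulation: during training the residual branch of a block is dropped with probability `drop_rate` and the surviving branches are rescaled. This is a generic illustration, not code from the paper; `drop_path` is a hypothetical name, and how the rate is distributed across blocks is not stated in these notes.

```python
import numpy as np

def drop_path(branch, drop_rate, training, rng):
    """Randomly zero a residual branch per sample.

    branch: (batch, ...) output of the attention or FFN sub-layer.
    Kept samples are scaled by 1 / (1 - drop_rate) so the expected value
    of the branch is unchanged.
    """
    if not training or drop_rate == 0.0:
        return branch
    keep = 1.0 - drop_rate
    # One Bernoulli draw per sample, broadcast over the remaining axes.
    mask_shape = (branch.shape[0],) + (1,) * (branch.ndim - 1)
    mask = rng.random(mask_shape) < keep
    return branch * mask / keep

rng = np.random.default_rng(0)
branch = np.ones((1000, 3))
dropped = drop_path(branch, drop_rate=0.5, training=True, rng=rng)
print(dropped.mean())  # close to 1.0 on average
```

In a block this would wrap each branch, for example `z_hat = z_prev + drop_path(attn_out, rate, training, rng)`.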
Performance comparison of different backbones on the ADE20K segmentation task. Two different frameworks, Semantic FPN and UPerNet, are used. FLOPs are calculated with resolution 512x2048. “+” means the model is pretrained on ImageNet-21K and finetuned with 640x640 resolution.
component analysis
Ablation study of each component to better understand CSWin Transformer. “SA”, “Arch”, and “CTE” denote “Self-Attention”, “Architecture”, and “Convolutional Token Embedding” respectively.
self-attention mechanism
the shallow-wide design from the above subsection is used: 2, 2, 6, 2 blocks for the four stages, with a base channel of 96
non-overlapped token embedding and RPE are applied in the above models
Ablation study of different self-attention mechanisms and positional encoding mechanisms. “*” denotes applying CPE before every Transformer block.
positional encoding
positional encoding brings a performance gain by introducing local inductive bias
LePE performs better on downstream tasks where the input resolution varies
stripes width
vary [$sw_1$, $sw_2$, $sw_3$] of the first three stages of CSWin-T and keep the last stage at $sw_4=7$
Ablation study on different stripe widths. We show the sw of each stage in the form [$sw_1$, $sw_2$, $sw_3$, $sw_4$] beside each point, and the X axis is the corresponding FLOPs.
as sw increases, FLOPs increase; accuracy improves greatly at the beginning, and the gain slows down once [$sw_1$, $sw_2$, $sw_3$] are large enough
the default setting [1, 2, 7, 7] for [$sw_1$, $sw_2$, $sw_3$, $sw_4$] achieves a better trade-off between accuracy and computation cost