[2107] [NIPS 2021] Focal Self-attention for Local-Global Interactions in Vision Transformers

paper
code

Content

    • Contribution
    • Method
        • model architecture
        • focal self-attention (FSA)
          • window-wise attention
          • focal transformer encoder
          • computational complexity
        • architecture variants
    • Experiment
        • image classification
        • object detection and instance segmentation
        • semantic segmentation
        • ablation study

Contribution

  • propose Focal self-attention (FSA): fine-grained attention locally and coarse-grained attention globally

Method

model architecture

Figure: Model architecture for our Focal Transformers. As highlighted in light blue boxes, our main innovation is the proposed focal self-attention mechanism in each Transformer layer.

focal self-attention (FSA)

Figure: Left: visualization of the attention maps of the three heads at the given query patch (blue) in the first layer of the DeiT-Tiny model. Right: an illustrative depiction of the focal self-attention mechanism. Three granularity levels are used to compose the attention region for the blue query.

FSA attends to fine-grained tokens only locally, instead of attending to all tokens at fine granularity
it covers as many regions as standard self-attention but at much lower cost

Figure: The size of the receptive field (y-axis) as the number of used tokens increases (x-axis), for standard and focal self-attention. For focal self-attention, the window granularity is assumed to increase gradually by a factor of 2, up to at most 8. Note that the y-axis is logarithmic.

for a query position, using gradually coarser granularity for its far surroundings gives FSA a significantly larger receptive field than the baseline while attending to the same number of visual tokens
the focal mechanism thus enables long-range self-attention with much less time and memory cost
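
To make the trade-off concrete, here is a toy back-of-envelope calculation. The numbers (four focal levels, 6x6 attended sub-windows per level, granularity doubling up to 8) are assumptions for illustration only, not the paper's exact configuration.

```python
# Toy comparison: receptive field vs. number of attended tokens.
# Assumed setup: 4 focal levels, 6x6 attended sub-windows per level,
# sub-window granularity doubling 1 -> 2 -> 4 -> 8 (capped at 8).
n = 6                              # attended sub-windows per side at each level
granularities = [1, 2, 4, 8]       # sub-window size (in original tokens) per level

focal_tokens = len(granularities) * n * n    # tokens actually attended by FSA
receptive_side = n * max(granularities)      # the coarsest level spans the widest area
standard_tokens = receptive_side ** 2        # standard attention over the same area

print(f"focal:    {focal_tokens} tokens for a {receptive_side}x{receptive_side} receptive field")
print(f"standard: {standard_tokens} tokens for the same {receptive_side}x{receptive_side} field")
```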

window-wise attention

Figure: An illustration of focal self-attention at the window level. Each of the finest square cells represents a visual token, either from the original feature map or from a squeezed one. Suppose we have an input feature map of size 20x20. We first partition it into 5x5 windows of size 4x4. Taking the 4x4 blue window in the middle as the query, we extract its surrounding tokens at multiple granularity levels as its keys and values. For the first level, we extract the 8x8 tokens closest to the blue window at the finest grain. At the second level, we expand the attention region and pool the surrounding 2x2 sub-windows, which results in 6x6 pooled tokens. At the third level, we attend to an even larger region covering the whole feature map and pool 4x4 sub-windows. Finally, these three levels of tokens are concatenated to compute the keys and values for the 4x4=16 tokens (queries) in the blue window.

first, define three terms for clarity

  1. focal levels $L$: the number of granularity levels at which tokens are extracted for focal self-attention
  2. focal window size $s_w^l$: the size of the sub-window over which tokens are summarized at level $l \in \{1, ..., L\}$
  3. focal region size $s_r^l$: the number of attended sub-windows, horizontally and vertically, at level $l$

focal self-attention proceeds in two main steps (a minimal code sketch follows this list)

  1. sub-window pooling
    given an input feature map $x \in \mathbb{R}^{h\times w\times C}$, split it into sub-windows of size $s_w^l \times s_w^l$:
    $\hat{x} = \mathrm{Reshape}(x) \in \mathbb{R}^{\frac{h}{s_w^l} \times \frac{w}{s_w^l} \times C \times (s_w^l \times s_w^l)}$
    then use a linear layer $f_p^l$ to pool each sub-window spatially:
    $x^l = f_p^l(\hat{x}) \in \mathbb{R}^{\frac{h}{s_w^l} \times \frac{w}{s_w^l} \times C}$
  2. attention computation
    with the pooled feature maps $\{x^l\}_1^L$, compute the queries, keys and values with linear projection layers:
    $Q = f_q(x^1), \quad K = \{K^l\}_1^L = f_k(\{x^1, ..., x^L\}), \quad V = \{V^l\}_1^L = f_v(\{x^1, ..., x^L\})$
    first extract the surrounding tokens for each query token in the feature map
    note that tokens inside a window partition of size $s_p \times s_p$ share the same set of surroundings
    for the queries in the $i$-th window, $Q_i \in \mathbb{R}^{s_p \times s_p \times C}$, extract the $s_r^l \times s_r^l$ keys and values from $K^l$ and $V^l$ around the window the query lies in
    then gather the keys and values from all $L$ levels to obtain
    $K_i = \{K_i^1, ..., K_i^L\} \in \mathbb{R}^{s\times C}, \quad V_i = \{V_i^1, ..., V_i^L\} \in \mathbb{R}^{s\times C}$
    where $s$ is the sum of the focal regions over all levels, i.e., $s = \sum_{l=1}^{L}(s_r^l)^2$
    note that a strict version of focal self-attention requires excluding the regions that overlap across different levels
    finally, include a relative position bias and compute the focal self-attention as
    $\mathrm{Attention}(Q_i, K_i, V_i) = \mathrm{Softmax}\left(\frac{Q_i K_i^T}{\sqrt{d}} + B\right)V_i$
    where $B = \{B^l\}_1^L$ is a learnable relative position bias, consisting of $L$ subsets for the $L$ focal levels
  • for the first level, parameterize the bias as $B^1 \in \mathbb{R}^{(2s_p-1)\times(2s_p-1)}$,
    since the horizontal and vertical relative positions both range in $[-s_p+1, s_p-1]$
  • for the other levels, because their granularity differs from that of the queries, treat all queries inside a window equally
    and use $B^l \in \mathbb{R}^{s_r^l \times s_r^l}$ to represent the relative position bias between the query window and each of the $s_r^l \times s_r^l$ pooled tokens
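
The two steps above can be sketched in a few lines of PyTorch. This is a minimal, single-head illustration under simplifying assumptions, not the official implementation: it treats the whole feature map as one query window, lets every level attend to all of its pooled tokens rather than only the $s_r^l \times s_r^l$ surrounding sub-windows, and omits the relative position bias $B$ and the multi-head split. The class and parameter names (`FocalWindowAttention`, `sub_window_sizes`) are hypothetical and do not come from the official code.

```python
import torch
import torch.nn as nn

# Minimal single-head sketch of window-wise focal self-attention (illustrative only).
# Simplifications vs. the paper: one query window covering the whole map, every level
# attends to all of its pooled tokens, no relative position bias, no multi-head split.
class FocalWindowAttention(nn.Module):
    def __init__(self, dim, sub_window_sizes=(1, 2, 4)):
        super().__init__()
        self.sw = sub_window_sizes                          # focal window sizes s_w^l
        # step 1: one linear pooling layer f_p^l per level, applied to flattened sub-windows
        self.pool = nn.ModuleList([nn.Linear(s * s, 1) for s in self.sw])
        # step 2: projections; queries come from the finest level, keys/values from all levels
        self.f_q = nn.Linear(dim, dim)
        self.f_k = nn.Linear(dim, dim)
        self.f_v = nn.Linear(dim, dim)
        self.scale = dim ** -0.5

    def pool_level(self, x, l):
        # x: (B, H, W, C) -> pooled map x^l of shape (B, H/s_w^l, W/s_w^l, C)
        B, H, W, C = x.shape
        s = self.sw[l]
        x = x.view(B, H // s, s, W // s, s, C).permute(0, 1, 3, 5, 2, 4)
        x = x.reshape(B, H // s, W // s, C, s * s)          # flatten each sub-window
        return self.pool[l](x).squeeze(-1)                  # spatial pooling via f_p^l

    def forward(self, x):
        # x: (B, H, W, C) feature map; H and W must be divisible by every s_w^l
        B, H, W, C = x.shape
        levels = [self.pool_level(x, l) for l in range(len(self.sw))]
        q = self.f_q(levels[0].reshape(B, -1, C))                       # queries: finest level
        kv = torch.cat([lv.reshape(B, -1, C) for lv in levels], dim=1)  # keys/values: all levels
        k, v = self.f_k(kv), self.f_v(kv)
        attn = (q @ k.transpose(-2, -1)) * self.scale                   # (B, Nq, Nk)
        return attn.softmax(dim=-1) @ v                                 # (B, Nq, C)

# toy usage on the 20x20 example from the figure above
x = torch.randn(2, 20, 20, 32)
out = FocalWindowAttention(32, sub_window_sizes=(1, 2, 4))(x)
print(out.shape)   # torch.Size([2, 400, 32])
```
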
focal transformer encoder

with encoder blocks containing FSA, the Transformer encoder is computed as
$\hat{z}_l = \mathrm{FSA}(\mathrm{LN}(z_{l-1})) + z_{l-1}$
$z_l = \mathrm{FFN}(\mathrm{LN}(\hat{z}_l)) + \hat{z}_l$
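
A minimal pre-norm block matching these two equations, in PyTorch. The `attn` argument can be any module mapping $(B, N, C)$ tokens to $(B, N, C)$; the 4x FFN expansion and the plain linear layer used as a stand-in for FSA in the usage line are assumptions for illustration.

```python
import torch
import torch.nn as nn

# Pre-norm block: z_hat = FSA(LN(z)) + z ; z = FFN(LN(z_hat)) + z_hat.
class FocalEncoderBlock(nn.Module):
    def __init__(self, dim, attn, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = attn                      # any (B, N, C) -> (B, N, C) attention module
        self.ffn = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim),
        )

    def forward(self, z):
        z = z + self.attn(self.norm1(z))      # z_hat = FSA(LN(z_{l-1})) + z_{l-1}
        z = z + self.ffn(self.norm2(z))       # z_l   = FFN(LN(z_hat)) + z_hat
        return z

# toy usage with a plain linear layer standing in for FSA
block = FocalEncoderBlock(dim=32, attn=nn.Linear(32, 32))
print(block(torch.randn(2, 196, 32)).shape)   # torch.Size([2, 196, 32])
```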

computational complexity

in ViT, given an input feature map $x \in \mathbb{R}^{h\times w\times C}$, the FLOPs of standard MSA are
$\Omega(\mathrm{MSA}) = 4hwC^2 + 2(hw)^2C$
for focal self-attention on an input feature map $x \in \mathbb{R}^{h\times w\times C}$, there are $\frac{h}{s_p}\times\frac{w}{s_p}$ windows, and at focal level $l$:
pooling each sub-window of size $s_w^l \times s_w^l$ costs
$\Omega(\mathrm{pool}) = (s_w^l)^2 C$
aggregating the sub-windows over the $h\times w$ feature map of each layer costs
$\Omega(\mathrm{aggr}) = hwC$
the attention cost for one query window of size $s_p \times s_p$ is
$\Omega(\mathrm{attn}_{win}) = (s_p)^2 C \sum_{l}(s_r^l)^2$
and over the whole feature map
$\Omega(\mathrm{attn}_{feat}) = hwC\sum_{l}(s_r^l)^2$
summing up, for FSA
$\Omega(\mathrm{FSA}) = L \times \Omega(\mathrm{aggr}) + \Omega(\mathrm{attn}_{feat}) = hwC\left(L + \sum_{l}(s_r^l)^2\right)$
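
As a quick numeric sanity check of the two formulas, the snippet below plugs in example values; the resolution, channel width, and focal region sizes are assumptions chosen for illustration, not the paper's configuration.

```python
# Example values (assumptions, not the paper's settings) plugged into the formulas above.
h = w = 56           # feature map resolution at an early stage
C = 96               # channel dimension
L = 3                # number of focal levels
s_r = [13, 7, 7]     # hypothetical focal region sizes s_r^l

msa = 4 * h * w * C**2 + 2 * (h * w) ** 2 * C      # Omega(MSA): quadratic in h*w
fsa = h * w * C * (L + sum(s**2 for s in s_r))     # Omega(FSA): linear in h*w

print(f"MSA: {msa / 1e9:.2f} GFLOPs")              # ~2.00 GFLOPs
print(f"FSA: {fsa / 1e9:.2f} GFLOPs")              # ~0.08 GFLOPs (pooling + attention terms only)
```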

architecture variants

Figure: Model configurations for our Focal Transformers. We introduce three configurations, Focal-Tiny, Focal-Small and Focal-Base, with different model capacities.

Experiment

image classification

dataset ImageNet-1K, with augmentation and regularization as in DeiT
optimizer AdamW: batch size=1024, 300 epochs, init lr=1e-3, weight decay=0.05, linear warm-up 20 epochs, cosine decay
stochastic depth 0.2, 0.2, 0.3 for Focal-T, Focal-S, Focal-B
max gradient norm clipped to 5.0

Figure: Comparison of image classification on ImageNet-1K for different models. Except for ViT-Base/16, all other models are trained and evaluated at 224x224 resolution.

object detection and instance segmentation

framework Mask R-CNN, Cascade Mask R-CNN
dataset COCO 2017
optimizer AdamW: 12 or 36 epochs, init lr=1e-4, weight decay=0.05
stochastic depth 0.2, 0.2, 0.3 for Focal-T, Focal-S, Focal-B

Figure: Comparisons with CNN and Transformer baselines and SoTA methods on COCO object detection. The box mAP ($AP^b$) and mask mAP ($AP^m$) are reported for RetinaNet and Mask R-CNN trained with the 1x schedule.

Figure: COCO object detection and segmentation results with RetinaNet and Mask R-CNN. All models are trained with the 3x schedule and multi-scale inputs (MS). The numbers before and after “/” in columns 2 and 3 are the model size and complexity for RetinaNet and Mask R-CNN, respectively.

dataset COCO 2017
optimizer AdamW: 36 epochs, init lr=1e-4, weight decay=0.05
stochastic depth 0.2, 0.2, 0.3 for Focal-T, Focal-S, Focal-B

Figure: Comparison with ResNet-50 and Swin-Tiny across different object detection methods. Focal-Tiny is used as the backbone and all models are trained with the 3x schedule.

semantic segmentation

dataset ADE20K
optimizer AdamW: batch size=16, 160K iterations, init lr=6e-5, weight decay=0.01, polynomial decay
scaling ratios [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale evaluation

Figure: Comparison with SoTA methods for semantic segmentation on the ADE20K val set. Both single- and multi-scale evaluations are reported in the last two columns. “\neq” means pretrained on ImageNet-22K.

ablation study

window size
one question is whether further increasing the window size helps model learning, given the enlarged receptive fields

Figure: Impact of different window sizes (WSize). We alter the default size 7 to 14 and observe consistent improvements for both methods.

necessity of window shift
window shift operations enable cross-window interactions between two successive layers

Figure: Impact of window shift (W-Shift) on Swin Transformer and Focal Transformer. Tiny models are used.

short- and long-range interactions
ablate the Focal-Tiny model into three variants:

  1. Focal-Tiny-Window merely performs attention inside each window
  2. Focal-Tiny-Local additionally attends to fine-grained surrounding tokens
  3. Focal-Tiny-Global additionally attends to coarse-grained squeezed tokens

Figure: Ablating the Focal-Tiny model by adding local, global and both interactions, respectively. Blue bars are for image classification and orange bars indicate object detection performance. Both local and global interactions are essential to obtain good performance.

model depth
since focal attention enables local and global interactions at each Transformer layer, one question is whether fewer layers are needed to obtain modeling capacity similar to models without global interactions
reduce the number of Transformer layers at stage 3 of Swin-Tiny and Focal-Tiny from 6 to 4 and then 2

Figure: Impact of the change of model depth. The number of Transformer layers at the third stage is gradually reduced from the original 6 to 4 and further to 2. This clearly hurts performance, but the Focal Transformer degrades much more slowly than the Swin Transformer.
