paper
code
main limitations of ViT
Feature visualization of ResNet50, ViT-L/16 and T2T-ViT-24 trained on ImageNet. Green boxes highlight learned low-level structure features such as edges and lines; red boxes highlight invalid feature maps with zero or too large values. Note the feature maps visualized here for ViT and T2T-ViT are not attention maps, but image features reshaped from tokens. For better visualization, we scale the input image to size 1024x1024 or 2048x2048.
ResNet captures desired local structure (edges, lines, textures, etc.) progressively from the bottom layer (conv1) to the middle layer (conv25)
in ViT, structure information is poorly modeled, while global relations (e.g., the whole dog) are captured by all attention blocks
note that ViT ignores local structure because it directly splits images into fixed-length tokens.
many channels in ViT have zero values
note that the backbone of ViT is not as efficient as ResNets and offers limited feature richness when training samples are insufficient.
main contributions of T2T-ViT
Comparison of T2T-ViT with ViT, ResNets and MobileNets when trained from scratch on ImageNet. Left: performance curve of MACs vs. top-1 accuracy. Right: performance curve of model size vs. top-1 accuracy.
The overall network architecture of T2T-ViT. In the T2T module, the input image is first soft split as patches, and then unfolded as a sequence of tokens $T_0$. The length of tokens is reduced progressively in the T2T module (we use 2 iterations here and output $T_f$). Then the T2T-ViT backbone takes the fixed tokens as input and outputs the predictions.
aims to overcome the limitation of the simple tokenization in ViT
progressively structurizes an image into tokens and models local structure information; the length of tokens is reduced iteratively
Illustration of the T2T process. The tokens $T_i$ are re-structurized as an image $I_i$ after transformation and reshaping; then $I_i$ is split with overlapping to tokens $T_{i+1}$ again. Specifically, as shown in the pink panel, the four tokens (1, 2, 4, 5) of the input $I_i$ are concatenated to form one token in $T_{i+1}$. The T2T transformer can be a normal Transformer layer or other efficient transformers like the Performer layer at limited GPU memory.
given a sequence of tokens $T_i$ from the preceding transformer layer, $T_i$ is transformed into $T_i'$ by a self-attention block
$T_i' = MLP(MSA(T_i))$
where $T_i \in R^{L\times C}$, $T_i' \in R^{L\times C}$
the tokens $T_i'$ are then reshaped as an image in the spatial dimension
$I_i = Reshape(T_i')$
where $Reshape(\cdot)$ re-organizes $T_i' \in R^{L\times C}$ into $I_i \in R^{H\times W\times C}$, with $L = H\times W$
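As a rough illustration, here is a minimal PyTorch sketch of one re-structurization step. It assumes a standard pre-norm transformer layer (residual connections and LayerNorm are omitted in the shorthand formula above), and the names `T2TTransformerLayer` and `restructurize` are illustrative, not taken from the official code.

```python
import torch
import torch.nn as nn

class T2TTransformerLayer(nn.Module):
    """Sketch of one T2T transformer step: T_i' = MLP(MSA(T_i)).
    Residual connections and LayerNorm (standard in transformer layers) are included
    even though the shorthand formula above omits them."""
    def __init__(self, dim, num_heads=1, mlp_ratio=1.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, tokens):                                        # tokens T_i: (B, L, C)
        y = self.norm1(tokens)
        tokens = tokens + self.attn(y, y, y, need_weights=False)[0]   # MSA
        return tokens + self.mlp(self.norm2(tokens))                  # MLP -> T_i'

def restructurize(tokens, h, w):
    """Reshape(T_i'): re-organize tokens in R^{L x C} into an 'image' I_i with L = H x W.
    Returns channel-first (B, C, H, W) so that nn.Unfold can consume it directly."""
    b, l, c = tokens.shape
    assert l == h * w
    return tokens.transpose(1, 2).reshape(b, c, h, w)
```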
model local structure information and reduce the length of tokens
similar to a convolution operation, but without conv filters
to avoid information loss when generating tokens from the re-structurized image, the image is split into patches with overlapping
$T_{i+1} = SS(I_i)$
where $SS(\cdot)$ is the soft split operation, implemented by nn.Unfold
in nn.Unfold, given a tensor $X \in R^{B\times C\times H\times W}$, a $k\times k$ kernel is applied on $X$ to extract a patch $X_1 \in R^{C\times k\times k}$, which is then reshaped (flattened) into $X_1' \in R^{Ck^2}$
the output tensor is $Y \in R^{B\times Ck^2\times L_0}$ with $L_0 = H_0\times W_0$, where $H_0 = \lfloor \frac{H-k+2p}{s}+1 \rfloor$, $W_0 = \lfloor \frac{W-k+2p}{s}+1 \rfloor$ ($k$: kernel size, $s$: stride, $p$: padding)
similarly, given $I_i \in R^{H\times W\times C}$, the output tokens are $T_{i+1} \in R^{L_0\times Ck^2}$, with $L_0 = \lfloor \frac{H-k+2p}{s}+1 \rfloor \times \lfloor \frac{W-k+2p}{s}+1 \rfloor$
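The soft split itself is one call to nn.Unfold. A small sketch, using example kernel/stride/padding values (an assumption, roughly matching the first overlapping split on a 224x224 input), confirms the output-length formula:

```python
import torch
import torch.nn as nn

def soft_split(img, k, s, p):
    """SS(I_i): split an image into overlapping k x k patches and flatten each patch
    into a token, giving T_{i+1} of shape (B, L0, C*k*k)."""
    patches = nn.Unfold(kernel_size=k, stride=s, padding=p)(img)   # (B, C*k*k, L0)
    return patches.transpose(1, 2)                                 # (B, L0, C*k*k)

B, C, H, W = 1, 3, 224, 224
k, s, p = 7, 4, 2                     # example values, not necessarily the paper's exact settings
tokens = soft_split(torch.randn(B, C, H, W), k, s, p)
H0 = (H - k + 2 * p) // s + 1         # 56
W0 = (W - k + 2 * p) // s + 1         # 56
print(tokens.shape, H0 * W0)          # torch.Size([1, 3136, 147]) 3136, i.e. L0 = H0 * W0, C*k^2 = 147
```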
after the soft split, the output tokens are fed into the next T2T process
T2T-ViT is based on transformer blocks, with 2 main components: the T2T module and the T2T-ViT backbone
for the input image $I_0$, a soft split is applied first to split it into tokens: $T_1 = SS(I_0)$
after the last T2T module, the output tokens $T_f$ have a fixed length, so the T2T-ViT backbone can model the global relation on $T_f$
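Putting the two steps together, the sketch below chains soft split, T2T transformer, and re-structurization for two iterations, reusing `T2TTransformerLayer` and `restructurize` from the sketches above. The kernel/stride pattern, the intermediate `token_dim`, and the explicit projection layers are assumptions for illustration; the official implementation organizes these projections differently inside its T2T attention layers.

```python
import torch
import torch.nn as nn

class T2TModule(nn.Module):
    """Sketch of a 2-iteration T2T module: SS -> (T2T transformer -> reshape -> SS) x 2 -> projection.
    Spatial sizes in the comments assume a 224 x 224 input."""
    def __init__(self, in_chans=3, token_dim=64, embed_dim=384):
        super().__init__()
        self.ss0 = nn.Unfold(kernel_size=7, stride=4, padding=2)   # T_1 = SS(I_0)
        self.ss1 = nn.Unfold(kernel_size=3, stride=2, padding=1)   # T_2 = SS(I_1)
        self.ss2 = nn.Unfold(kernel_size=3, stride=2, padding=1)   # T_f = SS(I_2)
        self.t2t1 = T2TTransformerLayer(in_chans * 7 * 7)
        self.proj1 = nn.Linear(in_chans * 7 * 7, token_dim)        # shrink channel dim (illustrative)
        self.t2t2 = T2TTransformerLayer(token_dim * 3 * 3)
        self.proj2 = nn.Linear(token_dim * 3 * 3, token_dim)
        self.project = nn.Linear(token_dim * 3 * 3, embed_dim)     # final projection to the backbone dim

    def forward(self, img):                                        # I_0: (B, 3, 224, 224)
        x = self.ss0(img).transpose(1, 2)                          # T_1: (B, 56*56, 147)
        x = self.proj1(self.t2t1(x))                               # T_1': (B, 3136, 64)
        x = restructurize(x, 56, 56)                               # I_1: (B, 64, 56, 56)
        x = self.ss1(x).transpose(1, 2)                            # T_2: (B, 28*28, 576)
        x = self.proj2(self.t2t2(x))                               # T_2': (B, 784, 64)
        x = restructurize(x, 28, 28)                               # I_2: (B, 64, 28, 28)
        x = self.ss2(x).transpose(1, 2)                            # (B, 14*14, 576)
        return self.project(x)                                     # T_f: (B, 196, embed_dim)
```

With a 224x224 input, the token length shrinks from 3136 to 784 to a fixed 196, which is the sequence the backbone consumes.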
since many channels in the vanilla ViT backbone are invalid, the aim is to find an efficient backbone that reduces redundancy and improves feature richness
5 architecture designs are transferred from CNNs to the ViT backbone
key findings
design a deep-narrow structure: a small channel (hidden) dimension $d$ but more layers $b$
for the tokens $T_f$ from the last T2T module, a class token is concatenated and a sinusoidal position embedding is added
$T_{f_0} = [t_{cls}; T_f] + PE, \quad PE \in R^{(L+1)\times d}$
$T_{f_i} = MLP(MSA(T_{f_{i-1}})), \quad i = 1, 2, \ldots, b$
$Y = FC(LN(T_{f_b}))$
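A hedged sketch of the backbone forward pass matching these three equations is given below; it reuses the `T2TTransformerLayer` sketch above as the MSA+MLP block, and the default `dim`/`depth`/`num_heads` values are placeholders rather than confirmed T2T-ViT configurations.

```python
import torch
import torch.nn as nn

def sinusoid_position_embedding(n_pos, dim):
    """Fixed sinusoidal position embedding PE in R^{n_pos x dim}."""
    pos = torch.arange(n_pos, dtype=torch.float32).unsqueeze(1)
    i = torch.arange(dim, dtype=torch.float32).unsqueeze(0)
    angle = pos / torch.pow(10000.0, 2 * (i // 2) / dim)
    pe = torch.zeros(n_pos, dim)
    pe[:, 0::2] = torch.sin(angle[:, 0::2])
    pe[:, 1::2] = torch.cos(angle[:, 1::2])
    return pe

class T2TViTBackbone(nn.Module):
    """Deep-narrow backbone sketch: b transformer layers of hidden dimension d,
    class token plus sinusoidal position embedding."""
    def __init__(self, dim=384, depth=14, num_heads=6, num_tokens=196, num_classes=1000):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))      # t_cls
        self.register_buffer("pos_embed",
                             sinusoid_position_embedding(num_tokens + 1, dim).unsqueeze(0))
        self.blocks = nn.ModuleList([T2TTransformerLayer(dim, num_heads, mlp_ratio=3.0)
                                     for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, tf):                                         # T_f: (B, L, d)
        cls = self.cls_token.expand(tf.shape[0], -1, -1)
        x = torch.cat([cls, tf], dim=1) + self.pos_embed           # T_{f_0} = [t_cls; T_f] + PE
        for blk in self.blocks:                                    # T_{f_i} = MLP(MSA(T_{f_{i-1}}))
            x = blk(x)
        return self.head(self.norm(x)[:, 0])                       # Y = FC(LN(T_{f_b})), on the class token
```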
Structure details of T2T-ViT. T2T-ViT-14/19/24 have comparable model size with ResNet50/101/152. T2T-ViT-7/12 have comparable model size with MobileNetV1/V2. For T2T transformer layer, we adopt Transformer layer for T2T-ViTt-14 and Performer layer for T2T-ViT-14 at limited GPU memory. For ViT, ‘S’ means Small, ‘B’ is Base and ‘L’ is Large. ‘ViT-S/16’ is a variant from original ViT-B/16 with smaller MLP size and layer depth.
dataset ImageNet
data augmentation mixup and cutmix, for both CNNs and ViTs
optimizer AdamW: batch size 512 or 1024, 310 epochs, cosine lr decay
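A minimal sketch of this training setup in plain PyTorch; the learning rate and weight decay are placeholder assumptions, and mixup/cutmix are only indicated by a comment:

```python
import torch

def train_from_scratch(model, train_loader, criterion, epochs=310, lr=1e-3, weight_decay=0.05):
    """AdamW + cosine lr decay over 310 epochs, as described in the notes.
    lr and weight_decay here are illustrative defaults, not values confirmed from the paper."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    for _ in range(epochs):
        for images, targets in train_loader:       # batch size 512 or 1024
            # mixup / cutmix would be applied to (images, targets) here
            loss = criterion(model(images), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
```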
Comparison between T2T-ViT and ViT by training from scratch on ImageNet.
Comparison between our T2T-ViT and ResNet on ImageNet. T2T-ViTt-14: using Transformer in the T2T module. T2T-ViT-14: using Performer in the T2T module. “*” means we train the model with our training scheme for fair comparisons.
Comparison between our lite T2T-ViT and MobileNets. Models with “-Distilled” are taught by a teacher model with the same method as DeiT.
fine-tuning in transfer learning
dataset CIFAR10, CIFAR100
optimizer SGD: 60 epochs, cosine lr decay
The results of fine-tuning the pretrained T2T-ViT to downstream datasets: CIFAR10 and CIFAR100.
T2T-ViT achieves higher performance than ViT with smaller model sizes on the downstream datasets
Transfer of some common designs in CNNs to ViT & T2T-ViT, including DenseNet, Wide-ResNet, SE module, ResNeXt, Ghost operation. The same color indicates the corresponding transfer. All models are trained from scratch on ImageNet. “*” means we reproduce the model with our training scheme for fair comparisons.
key findings
Ablation study results on the T2T module and the Deep-Narrow (DN) structure.
T2T module
T2T-ViT-14-woT2T: the same T2T-ViT backbone but without the T2T module
T2T-ViTc-14: the T2T module is replaced by 3 conv layers with kernel sizes (7, 3, 3) and strides (4, 2, 2)
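For reference, a sketch of what such a conv stem could look like; only the kernel sizes (7, 3, 3) and strides (4, 2, 2) follow the description above, while the channel widths, padding, and activations are assumptions:

```python
import torch.nn as nn

def conv_stem(in_chans=3, mid_chans=64, embed_dim=384):
    """Hypothetical 3-layer conv stem replacing the T2T module: kernels (7, 3, 3) with
    strides (4, 2, 2) give the same 16x downsampling, so a 224 x 224 image becomes a
    14 x 14 feature map, i.e. 196 tokens after flattening."""
    return nn.Sequential(
        nn.Conv2d(in_chans, mid_chans, kernel_size=7, stride=4, padding=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(mid_chans, mid_chans, kernel_size=3, stride=2, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(mid_chans, embed_dim, kernel_size=3, stride=2, padding=1),
    )
```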
deep-narrow structure
T2T-ViT-d768-4: a shallow-wide structure with a hidden dimension of 768 and 4 layers, with similar model size and MACs to T2T-ViT-14
after replacing the deep-narrow structure with the shallow-wide one, top-1 accuracy on ImageNet drops by 2.7%
the deep-narrow structure is crucial for T2T-ViT