AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
2020年Google团队提出
transformer在CV领域应用的里程碑著作
ViT:Vision Transformer
论文地址:https://arxiv.org/abs/2010.11929
代码地址:https://github.com/google-research/vision_transformer
发表时间:[Submitted on 22 Oct 2020 (v1), last revised 3 Jun 2021 (this version, v2)]
发表会议:Computer Vision and Pattern Recognition(顶会)
在视觉上,注意力要么与CNN结合使用,要么用于替换CNN的某些组件。
本文证明,这种对CNNs的依赖是不必要的,直接将Transformer应用于图像patch序列,可以很好的完成分类任务。
Transformer已经成为自然语言处理(NLP)中的主要模型。
在CV领域,CNN仍然占据主导地位。
受NLP中Transformer的启发,本文尝试将标准Transformer直接应用于图像,并进行少量修改。
【重要结论】
当在中等大小的数据集(如ImageNet)上进行训练时,如果没有强大的正则化,这些模型(Vit,Transformer)的精度会比同等大小的ResNets低几个百分点。因为Transformer和CNN比较,缺乏一些固有的归纳偏置(inductive bias)。
如果数据集的规模足够大,ViT则会超越CNN,Transformer可以突破归纳偏置的限制。
补充:CNN中的归纳偏置
归纳偏置其实就是一种先验知识,一种提前做好的假设。
在CNN中的归纳偏置一般包括两类:①locality(局部性)和②translation equivariance(平移等变性)
① locality:假设相同的区域会有相同的特征,靠得越近的东西相关性能也就越强。局部性可以控制模型的复杂度。
②translation equivariance:由于卷积核是一样的所以不管图片中的物体移动到哪里,只要是同样的输入进来遇到同样的卷积核,那么输出就是一样的。利用平移等变形可以很好的提高模型的泛化能力。
参考论文:https://arxiv.org/abs/2010.08515
Transformer背景(注意力机制 计算开销);
CNN+注意力机制的背景;
相关模型:GPT (iGPT) (Chen et al., 2020a);
模型的设计尽可能遵循原Transformer,这样可以使模型更好的扩展到其他领域;
图1:模型概述。我们将图像分割成固定大小的patch,线性嵌入每个patch,添加位置embedding,并将生成的矢量序列输入到标准Transformer编码器。为了执行分类,本文使用了向序列中添加了“classification token”。
patch embedding
标准的Transformer输入为1D序列的token embedding。对于2D图像,将一个图像 x ∈ R H × W × C x ∈ R^{H × W × C} x∈RH×W×C reshape成一个展开的2D的patch x p ∈ R N × ( P 2 C ) x_p ∈ R^{N × (P^2C)} xp∈RN×(P2C),P:patch大小; N = H W / P 2 : p a t c h 数 量 N = HW/P^2:patch数量 N=HW/P2:patch数量。然后Linear embedding成D维度(patch_dim)的特征向量(patch embedding);
与BERT类似,在嵌入的patch序列之前添加一个class token,( z 0 0 = x c l a s s z_0^0=x_{class} z00=xclass),表示预测的类别;
position embedding
position embedding被添加到patch embedding以保留位置信息。本年使用标准的可学习1Dposition embedding。所得的嵌入向量序列用作编码器的输入。
输入到Transformer
经过Transformer的attention机制和FFN;
一个block之后维度依然和输入相同,都是197x768,因此可以堆叠多个block。
代码:
class ViT(nn.Module):
def __init__(self, *, image_size=224, patch_size=16, num_classes=3, dim=768, depth=4, heads=3, mlp_dim=3072, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
# Rearrage的意思是将传入的image(3,224,224),按照(3,(h,p1),(w,p2))也就是224=hp1,224 = wp2
# 接着把shape变成b (h w) (p1 p2 c)格式的
# 这样把图片分成了每个patch并且将patch拉长,方便下一步的全连接层
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
# Sequential(
# (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
# (1): Linear(in_features=768, out_features=768, bias=True)
# )
# 位置嵌入
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.pool = pool
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, img):
# 得到图像的embedding
x = self.to_patch_embedding(img)
b, n, _ = x.shape
# 在patch embedding前加上cls token
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
# position embedding不是concat 而是sum
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
# 输入到transformer 经过 Attention FFN
x = self.transformer(x)
# 池化
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
# 转换到类别
return self.mlp_head(x)
Vision Transformer比CNN具有更少的图像特定归纳偏置。在ViT中,只有MLP层是局部和平移不变的,自注意力机制是全局的。
在混合模型中,将224x224图片送入CNN得到16x16的特征图,拉成一个向量,长度为196,后续操作和ViT相同。
一种特殊情况,patch可以为1x1,则可以直接flatten成一个向量。
根据上文所述添加cls token嵌入和position embedding。
通常,在一个大型数据集上预训练ViT,然后在下游任务相对小的数据集上微调。
当图像的分辨率更高时,保持patch大小不变,这会导致有效序列变长。Transformer可以处理任意的序列长度,但是,预训练的位置嵌入可能不再有意义。
一个方法是使用插值算法,扩大位置编码表。但是如果序列长度变化过大,插值操作会损失模型性能,这是ViT在微调时的一种局限性。