论文地址:https://arxiv.org/abs/2010.11929
代码:https://github.com/google-research/vision_transformer
发表于:ICLR 2021(Arxiv 2020.11)
虽然Transformer架构已经成为自然语言处理任务的事实标准,但它在计算机视觉方面的应用仍然有限。在视觉中,注意力机制要么与卷积网络一起使用,要么用来取代卷积网络的某些组件,在此同时保持卷积网络的整体结构。我们表明,这种对CNN的依赖是没有必要的,直接应用于图像块(image patch)序列的纯Transformer在图像分类任务上可以表现得非常好。当在海量数据上进行预训练后,转移到多个中型或小型图像识别基准测试(ImageNet、CIFAR-100、VTAB)时,与卷积网络SOTA相比,Vision Transformer(ViT)取得了出色的结果,同时训练时所需的资源也更少。
该网络的特点是,用的就是NLP里面原始的transformer,做到了几乎“开箱即用”,没有去引入太多额外的特别改进。事实上本文做的额外工作也主要体现在如何将图像转化为可以输入transformer的embedding向量。
Transformer一开始设计是用于NLP任务的。也就是说,首先将单词串(word seq)进行embedding后,得到一串一维的token,那么这串token就是Transformer的输入。也就是说,如果我们想把图片给输入进Transformer的话,需要想办法把图片给转换为一串token。那么在本文中,这个token就是图像块(image patch)。即,把图像切成一块一块,然后串联成一个token串:
形式化地讲,对于输入图像 x ∈ R H × W × C \mathbf{x} \in \mathbb{R}^{H \times W \times C} x∈RH×W×C,我们需要得到一个展平的2D图像块序列: x p ∈ R N × ( P 2 ⋅ C ) \mathbf{x}_{p} \in \mathbb{R}^{N \times\left(P^{2} \cdot C\right)} xp∈RN×(P2⋅C) 其中 ( H , W ) (H,W) (H,W)为原始图像的分辨率, C C C为通道数, ( P , P ) (P,P) (P,P)为图像块的分辨率, N = H W / P 2 N=H W / P^{2} N=HW/P2为所得到的图像块的个数,同时也代表着transformer输入序列的长度。
做到上面这一步,想当于完成了由图像到单词串的转换,那么接下来还要完成"单词串"的embedding过程。embedding可以暂时非常粗糙地理解为把输入内容投影至一个特定的向量空间内,以便神经网络进行运算处理。NLP中使用的是Word Embedding,而本文提出了Patch Embedding的概念,将图像块投影至D维的latent vector,有: z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯ ; x p N E ] + E p o s , E ∈ R ( P 2 ⋅ C ) × D , E p o s ∈ R ( N + 1 ) × D \mathbf{z}_{0}=\left[\mathbf{x}_{\text {class }} ; \mathbf{x}_{p}^{1} \mathbf{E} ; \mathbf{x}_{p}^{2} \mathbf{E} ; \cdots ; \mathbf{x}_{p}^{N} \mathbf{E}\right]+\mathbf{E}_{p o s}, \quad \mathbf{E} \in \mathbb{R}^{\left(P^{2} \cdot C\right) \times D}, \mathbf{E}_{p o s} \in \mathbb{R}^{(N+1) \times D} z0=[xclass ;xp1E;xp2E;⋯;xpNE]+Epos,E∈R(P2⋅C)×D,Epos∈R(N+1)×D
顾名思义,Position Embedding的作用是保持位置信息,这个也是Transformer中原有的概念而非本文所提出来的。接下来我们分析为什么要有这么个Position Embedding,原transformer论文中的自注意力机制可以表示为如下:
可以看到,在只用self-attention的情况下,第一个输入的词 x 1 x^{1} x1与最后一个输入的词 x 4 x^{4} x4是没有顺序上区别的,相当于一个大号词袋,因此需要对embedding得到的向量引入额外的位置信息,即所谓Position Embedding,如下所示:
图中一个个粉色的块为经Patch Embedding后,由图像块转换而来的latent vector;而椭圆旁的0、1、2、3等紫色的块即为通过Position Embedding后得到的额外位置信息。
在原文中,并没有给出Position Embedding的详细计算方法,只提到了是使用了一维embedding。通过查看代码可知其采用的是随机初始化参数再进行训练的方法:
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
x += self.pos_embedding[:, :(n + 1)]
注意这里与原transformer论文中使用的sin cos编码法有一定的区别。最终,经过Patch Embedding与Position Embedding得到的向量串就成为了Encoder的输入。至于Encoder-Decoder架构,由于直接使用的是原transformer论文中的实现,因此本文我们不对其进行介绍。
本文的最大贡献在于证明了使用纯Transformer便可以完成计算机视觉中的一些任务,(表面上)不需要再使用卷积。不过一个有意思的点是,在使用常规的中等大小数据集(ImageNet)来训练ViT的时候,性能上其实是比ResNet要低的。文中的解释是Transformer缺乏CNN的一些固有特性如平移不变形等,因此在训练数据不足时泛化能力较差。但是本文发现了如果继续加数据的话,就可以取得非常优秀的性能。可以认为ViT的性能一定程度上归功于大数据集(ImageNet-21k/JFT-300M),因此如何使用更少的数据来训练ViT可能也是今后的一大方向。
[1] http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Transformer%20(v5).pdf
[2] https://www.bilibili.com/video/av56239558/