Transformer论文解读二(Vision Transformer)

最近Transformer在CV领域很火,Transformer是2017年Google发表的Attention Is All You Need中主要是针对自然语言处理领域提出的,后被拓展到各个领域。本系列文章介绍Transformer及其在各种领域引申出的应用。

本文介绍的An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale将Transformer应用到计算机视觉领域,称作Vision Transformer(ViT)。

Vision Transformer将图像分割成固定大小的小块,线性嵌入每个小块,添加位置编码,并将得到的向量序列输入标准Transformer编码器,主要包括三个模块:

  1. Linear Projection of Flattened Patches(嵌入层)
  2. Transformer Encoder(如下图右侧所示)。作用是将图片格式[H, W, C]转化为Transformer输入格式[num_token, token_dim]。
  3. MLP Head(最终用于分类的层结构)

Transformer论文解读二(Vision Transformer)_第1张图片

1. Embedding层结构

首先将一张图片按给定大小分成一堆Patches,接着通过线性映射将每个Patch映射到一维向量中,输入嵌入层,将输入图片的格式[H, W, C]变换为标准的Transformer模块的输入格式token(向量)序列,即二维矩阵[num_token, token_dim]。这一嵌入操作在代码实现中简单地通过卷积层和展平操作来实现。

在输入Transformer Encoder之前,需要加上token以及位置编码。 在刚刚得到的一堆tokens中,插入一个专门用于分类的[class]token,是一个可训练的参数。位置编码与Transformer中讲到的位置编码一致,采用的是一个可训练的参数,是直接加在tokens上。

Transformer论文解读二(Vision Transformer)_第2张图片
对于位置编码,作者做了一系列对比试验,在源码中默认使用的是1D Pos. Emb.,对比不使用位置编码准确率提升了大概3个点:

Transformer论文解读二(Vision Transformer)_第3张图片

2.Transformer Encoder

Transformer Encoder的结构主要由以下几部分组成:

  1. Layer Norm,对每个token进行Norm处理
  2. Multi-Head Attention
  3. Dropout/DropPath。
  4. MLP Block,由全连接+GELU激活函数+Dropout组成。需要注意的是第一个全连接层会把输入节点个数翻4倍,第二个全连接层会还原回原节点个数。

Transformer论文解读二(Vision Transformer)_第4张图片

3. MLP head

通过Transformer Encoder后输出和输入的形状是保持不变的。这里只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行。接着通过MLP Head得到最终的分类结果。论文中指出,在训练ImageNet21K时MLP head是由Linear+tanh激活函数+Linear组成,迁移到ImageNet1K上时,只用一个Linear即可。

Transformer论文解读二(Vision Transformer)_第5张图片

4. Hybrid模型

论文中还提出了Hybrid混合模型,就是将传统CNN特征提取和Transformer进行结合。例如采用Resnet50为backbone时,一些变化如下:

  1. 卷积层采用的StdConv2d不是传统的Conv2d,所有的BatchNorm层也替换成GroupNorm层。
  2. 堆叠次数也有所不同。在原Resnet50网络中,stage1重复堆叠3次,stage2重复堆叠4次,stage3重复堆叠6次,stage4重复堆叠3次,但在这里的R50中,把stage4中的3个Block移至stage3中,所以stage3中共重复堆叠9次。

通过R50 Backbone进行特征提取后,得到的特征矩阵再输入嵌入层,这时嵌入中卷积层Conv2d的kernel_size和stride都变成了1,只是用来调整channel。后面的部分和正常ViT一致。

参考文献
Vision Transformer详解

你可能感兴趣的:(计算机视觉,Transformer,分类,transformer,计算机视觉,深度学习)