Vision Transformer

1、前言

Transformer最初提出是针对NLP领域的,并且在NLP领域大获成功。这篇论文也是受到其启发,尝试将Transformer应用到CV领域。关于Transformer的部分理论之前的博文中有讲,链接,这里不在赘述。通过这篇文章的实验,给出的最佳模型在ImageNet1K上能够达到88.55%的准确率(先在Google自家的JFT数据集上进行了预训练),说明Transformer在CV领域确实是有效的,而且效果还挺惊人。

2、VIT 模型详解

模型分为三个模块

  • Linear Projection of Flattened Patches(Embedding层)
  • Transformer Encoder(图右侧有给出更加详细的结构)
  • MLP Head(最终用于分类的层结构)

Vision Transformer_第1张图片

Linear Projection of Flattened Patches(Embedding层)

图像像素有的很大,对于Transfomer 两两之间都有关联,所以是平方关系,如果像素很大,则平方就会很大,所以这里考虑的是将图片分成一个个patch,这里是16*16分为一个patch,相当于一句话中的一个单词。

比如224*224的图片

分为 14*14=196,每个patch 是16*16*3(通道数)=768

所以就为196*768  经过Linear Prijection 其实就是全连接 全连接矩阵参数 [768,tocken_dim]这里的tocken_dim也为768,最后经过矩阵运算 得到的最终维度是196*768[num_tocken,tocken_dim]

在输入Transformer Encoder之前注意需要加上[class]token以及Position Embedding。 

cls_tocken 是一个向量 1*tocken_dim ==>197*768

最后再加上一个位置信息,这里加上的是1D位置信息==》197*768.

在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。

Transformer Encoder

(图右侧有给出更加详细的结构)

Transformer Encoder其实就是重复堆叠Encoder Block L次,下图是我自己绘制的Encoder Block,主要由以下几部分组成:

Layer Norm,这种Normalization方法主要是针对NLP领域提出的,这里是对每个token进行Norm处理,之前也有讲过Layer Norm不懂的可以参考  链接
Multi-Head Attention,这个结构之前在讲Transformer中很详细的讲过,不在赘述,不了解的可以参考 链接
Dropout/DropPath,在原论文的代码中是直接使用的Dropout层,在但rwightman实现的代码中使用的是DropPath(stochastic depth),可能后者会更好一点。
MLP Block,如图右侧所示,就是全连接+GELU激活函数+Dropout组成也非常简单,需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]

Vision Transformer_第2张图片


MLP Head详解

上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。注意,在Transformer Encoder后其实还有一个Layer Norm没有画出来,后面有我自己画的ViT的模型可以看到详细结构。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。
Vision Transformer_第3张图片

为了方便大家理解,我自己根据源代码画了张更详细的图(以ViT-B/16为例):

Vision Transformer_第4张图片

你可能感兴趣的:(transformer,transformer,深度学习,人工智能)