Vision Transformer模型/论文详解

Vision Transformer 模型详解

  • 论文链接: https://arxiv.org/abs/2010.11929
  • 论文原文代码:https://github.com/google-research/vision_transformer
  • timm链接:https://github.com/rwightman/pytorch-image-models

Vision Transformer模型/论文详解_第1张图片


文章目录

  • Vision Transformer 模型详解
    • 1. 前言:
    • 2. 模型结构
      • 1. patch embedding部分
        • 一、patch embeding的具体流程
        • 二、 图像数据转化为transoformer所需要的序列
        • 三、 position embeding
        • 四 cls token
      • 2. transformer encoder部分
      • 3. MLPHead部分
    • 3.归纳偏置
    • 4.混合模型
    • 5. 训练和微调
    • 6. VIT模型实验展示和对比图

1. 前言:

​ 作者团队证明了脱离神经网络,使用一个纯的transformer结构也能在图像分类任务上表现的很好。甚至当我们在大规模数据集(google 的JTF300M或者Imagenet21K)上进行预训练然后在小数据集上进行微调时,它的表达效果甚至超过了传统的卷积神经网络。并且随着数据集的扩大vision transformer还没有出现过饱和的现象。

2. 模型结构

作者将transformer结构去除掉decoder部分后经过改进将其运用到视觉领域。vision transformer模型结构主要包括三部分

  • patch embeding 部分

  • transformer encoder部分

  • MLP head部分

论文中的模型框架图:

Vision Transformer模型/论文详解_第2张图片

先了解, 熟悉一下训练流程(后面会对每一个部分进行解释):

  • 图片先被切分成为若干个图像块,然后将图像块经过一个线性映射层后得到patch embed
  • 将得到的patch embed与cls token 进行concate.
  • 将concate后的结果与position embeding直接相加
  • 将得到的"序列"先进行Layer Norm归一化再经过Multi-Head Attention, Layer Norm, MLP后得到第一个Encoder的结果
  • 一次经过L个encoder后将cls token送入MLP Head 后得到网络输出的分类结果
  • 再与真实的label计算Cross Entropy Loss后进行反向传播

1. patch embedding部分

一、patch embeding的具体流程

patch embeding将一个CHW格式的图片数据处理成transformer所需要的num_patchs x embed_dim格式的序列,首先将一张图片拆分为若干个p x p x 3大小的图像块(p为patch size),再对每个图象块经过线性映射层产生transformer所需要的序列,以ViT-B/16为例看下图(自己画的)

Vision Transformer模型/论文详解_第3张图片

在代码中则是直接采用kernelsize为p(ViT-B/16中为16),stride为也为p,outchannel为pxpx3的卷积来进行图像切分和线性映射的实现,卷积后输出的结果形状为(n, embed_dim, HW/p^2, HW/p^2), 在VIT-B/16中为14,再对卷积产生的结果进行拉直,换位。再concatecls token后与position embeding对应元素相加得到patch embeding最后的结果。过程中的cls tokenposition embeding 请看下文。

二、 图像数据转化为transoformer所需要的序列

对于一个标准的transformer需要的输入是一个二维的序列,而我们的图片则是一个由像素组成的三维向量,那么如何将一张图片的数据转化为transformer中的序列呢?

作者在论文中提到:

  1. 直接使用原图片的所有像素构成一个序列输入到transformer中,这种方法是不行的,因为transformer的复杂度为n平方,一般分类任务图片大小为224x224,所需要的计算时间开销过大
  2. Stand-Alone Attention(孤立自注意力): 利用窗口的方式,使用窗口内的像素来代替整张图片的像素,减少序列长度。
  3. Axial Attention(轴自注意力):先对图片的H特征轴进行一次自注意力操作,再对图片的W特征轴进行注意力操作,降低生成的序列长度,解决时间开销问题。
  4. 将图片切分为若干个图片块,再将每一个图片块中的像素转化为序列进行操作。

vision transformer使用了第四种方式,第2,3种方式由于操作的特殊性无法在现代硬件技术上进行加速计算,无法构建大的网络模型来解决问题。

三、 position embeding

作者参照标准的transformer,采用与卷积不同的方式,对图片切分后产生的图像块进行编码,通过构建一个可学习的position embeding向量来表示图片的空间信息并与原有的数据逐元素相加,这里作者分别实验了几种不同的编码方式和不采用position embeding分别进行对比:

    1. 采用1D的编码方式(即直接从左到右,从上到下进行编码,为论文模型框架图所用的一致)
    1. 采用2D的编码方式(即对于分别对H,W轴进行编码)
    1. 采用相对位置信息进行编码(计算不同图象块之间的位置距离)

作者还实验了分在进入Encoder前加入position embeding, 在每层Encoder前都加入position embeding, 添加一个可学习的position embeding到每层网络的开始。

最终效果如图:

Vision Transformer模型/论文详解_第4张图片

从中可以看出不同的编码方式产生的效果差别不大,而不采用位置编码则效果不好,同时作者推测造成这种不同的编码方式产生近乎相同的效果的原因是:ViT将图片划分成了图象块,这些图像块的数量相对与原来图像中的像素点的数量来说很少,所以他们之间的位置信息很容易学习得到。甚至对于不同的超参数产生的学习产生的效果也很相似(如下图)

Vision Transformer模型/论文详解_第5张图片

四 cls token

作者为了尽量与原来的transformer保持一致,也构造了一个cls token向量,与原来的patch embeding堆叠到一起(concate)用来表示网络学习得到的结果,最后只将cls token 送入MLP Head来产生分类结果,同时还实验了不采用cls token,直接对最后一层输出的token 进行全局平均池化产生的效果也都差不多。

Vision Transformer模型/论文详解_第6张图片

2. transformer encoder部分

由于每个encoder layer并没有改变每个序列(token)的长度,所以直接由多个encoder layer堆叠起来组成了transformer encoder部分,而encoder layer部分主要由MSA(Multi-Head-Attention ), LN(Layer-Norm)和MLP组成

一个MSA是由多个self-Attention多用在同一个token上组成,这里给出百度飞桨朱欤老师所画的图来解释说明课程链接

Vision Transformer模型/论文详解_第7张图片
Vision Transformer模型/论文详解_第8张图片

对每一个图像块形成的image token分别由三个线性层得到q, k, v 三个部分其中:

k:表示token的key值

q:query分别对其他的image token产生的key值进行对应元素相乘来得到s,在经过scale归一化和softmax之后得到Attention的权重矩阵(Attention Weight)

V:表示token经过线性层后提取出来的信息,与Attention Weight相乘后得到self-Attention的结果

个人对于self-Attention的理解:

  • 通过产生的q和k相乘后进行scale和softmax归一化来计算不同位置图片块之间的相似度从而构建了Mask(掩码)的方式给予一个特征权重,相似度大的部分权重越大,相似度小的部分权重则越小

  • 再使用这个特征的权重矩阵与原来的token经过线性层后提取出来的特征对应元素相乘,增加模型对相似度高的部分的表征能力(也叫做注意力)

而MSA则是分别使用多个self-Attention对token进行处理(注意:每个self-Attention作用在token的不同地方没有重叠,就是将输入的token平均分成了若干个部分给self-Attention来处理), 再对每一个self-Attention处理完成后的结果按顺序组合拼接即可。

Vision Transformer模型/论文详解_第9张图片
(这张图也来自于飞桨朱欤老师)

这部分没理解的,具体可以看另外一位大佬写的博客:链接

Layer Norm相比于Batch Norm不同的地方在于Batch Norm是对卷积神经网络中间的特征图每一层做归一化,而Layer Norm则是对整个token(相当与卷积神经网络特征图的整体)做归一化,在框架中都有相应的api来实现

MLP层实际上是由前后两层线性层中间加上GELU激活函数和DropOut层组成,这里需要注意的是中间的两个线性层中的第一个将token的长度放大为原来的4倍,后一个线性层则将长度重新恢复。

3. MLPHead部分

前面所述ViT为了参照标准的transformer只将cls token作为模型最终的处理结果来送入MLPHead中。而MLP Head中包含了两种结构,当进行预训练的时候,MLPHead部分有两个线性层,而在中小型数据集上微调时则为一个线性层。

Vision Transformer模型/论文详解_第10张图片

注意:在经过transformer encoder之后的token先进行了一次Layer Norm最后才送入MLP Head, 在论文框架图中并未给出

3.归纳偏置

​ 作者经过实验发现transformer在中小型数据集上的表现效果实际上是不如卷积神经网络的,但是如果将transformer在大规模数据集上进行预训练然后在中小数据集进行微调时效果超过了卷积神经网络。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hLkID8Oe-1642319897818)(/home/shier/Pictures/ViT/Inductive.png)]

论文中表示造成这种结果的原因是卷积神经网络存在归纳偏置(Inductive Bias),这种归纳偏置具有平移等变性(translation equivariance)、局部性(locality)

translation equviriance : 无论是先做卷积还是先做平移最终将产生相同的结果

locality:图片在做卷积时,因为步长的原因,中间存在有重叠的部分,相邻的区域会有相同的特征

个人理解(仅供参考,可能存在错误):

当存在两张图片,图片中物体相同而处于图片不同位置时,卷积神经网络因为归纳偏置产生的效果相似,而transformer需要学习不同patch之间的位置信息,学习的困难程度将会增加。

4.混合模型

由于transformer中缺乏卷积神经网络中的归纳偏置, 作者提出可以选择一个卷积神经网络模型来当作transformer的特征提取器,对卷积神经网络模型产生的特征图进行patch embeding。论文中选择了resnet50进行实验,此时的resnet50是进行了修改之后的resnet50。将模型中的Batch Norm层替换成了Group Norm, 传统的卷积替换成了stdconv, 并且去除了resnet50中的stage4,将stage4中的结构添加到了stage3中来当作ViT的特征提取器。

5. 训练和微调

由于vision transformer中图片的分辨率大小和patch size的大小会影响patch embeding之后的序列长度,就是说如果想直接在中小型数据集上进行微调时扩大输入图片的像素来增加训练效果是不行的,这会增大原有的序列长度,导致预训练的成果失效,不过我们可以在原有的序列上进行2D的线性插值来应对这个问题,不过这种解决方式还是会造成效果降低。

6. VIT模型实验展示和对比图

Vision Transformer模型/论文详解_第11张图片

Vision Transformer模型/论文详解_第12张图片

Vision Transformer模型/论文详解_第13张图片

参考:

csdn霹雳吧啦_wz:https://blog.csdn.net/qq_37541097/article/details/118242600

csdn霹雳吧啦_wz:https://blog.csdn.net/qq_37541097/article/details/117691873

知乎:https://zhuanlan.zhihu.com/p/356155277

飞桨:https://aistudio.baidu.com/aistudio/education/group/info/25102

官方论文: https://arxiv.org/abs/2010.11929

论文代码:https://github.com/google-research/vision_transformer

哔哩哔哩视频:https://www.bilibili.com/video/BV15P4y137jb?spm_id_from=333.1007.top_right_bar_window_custom_collection.content.click


本贴写于:2022年1月16日。完。

你可能感兴趣的:(深度学习,transformer,深度学习,pytorch)