Mobile-ViT (MobileViT)网络讲解

目录

  • 前言
  • 一.Transformer
    • 1.1.Transformer存在的问题
    • 1.2.Vision Transformer
  • 二.Mobile-ViT
      • 2.1.MV2
      • 2.2.MobileViT
      • 2.3.模型配置

前言

  上篇博文我们分析了VIT的代码,有不了解的小伙伴可以去看下:Vision Transformer(VIT)代码分析——保姆级教程。这篇博文我们先介绍下Mobile-ViT的原理部分,代码分析我们下篇博文再介绍。下面附上论文和官方代码。

  • 论文连接:https://arxiv.org/abs/2110.02178
  • 官方代码:https://github.com/apple/ml-cvnets

一.Transformer

  在学习Mobile-ViT之前,建议各位小伙伴先学习下Transformer的知识点,不然直接看Mobile-ViT可能会有点吃力。关于Transformer的学习可以参考以下给出的视频和博文链接:

  • 【機器學習2021】自注意力機制 (Self-attention) (上)
  • 【機器學習2021】自注意力機制 (Self-attention) (下)
  • 【機器學習2021】Transformer (上)
  • 【機器學習2021】Transformer (下)
  • 详解Transformer中Self-Attention以及Multi-Head Attention
  • Vision Transformer详解

这里我们先简单的介绍下Transformer存在的问题以及我们常用的视觉Transformer网络ViT

1.1.Transformer存在的问题

  VIT及其变种性能都已经很强大了,那么我们为什么要有Mobile-ViT网络呢?根据Transformer的原理我们可以知道当前的Transformer模型主要存在以下问题:

  • 参数多,算力要求高
  • 缺少空间归纳偏置:即纯Transformer对空间位置信息不敏感,但是呢,我们在进行视觉应用的时候位置信息又比较重要,为了解决这个问题就引入了位置编码。
  • 不容易迁移到其他任务:这个问题核心还是引入的位置编码导致的,如ViT中采用的是绝对位置偏置,绝对位置编码的序序列长度是和输入序列的长度保持一致,即在训练模型的时候制定了输入图像的大小之后他的长度是固定的,这样在后面训练的时候如果改变了输入图像的大小,那么位置编码序列的长度和输入序列的长度就会不一致,导致无法训练。因此又出现了插值,即将绝对位置编码插值到与输入序列相同的长度。但是呢,插值也会出现问题,把插值的模型拿来直接用的话会出现模型性能下降的问题,如在 224 × 224 224\times224 224×224上训练, 512 × 512 512\times512 512×512上进行测试,理论上应该是性能会提升,但实际上会下降。因此插值之后还要对模型进行微调,这样就很麻烦。于是又出现了相对位置编码,如Swin Transformer,他只和window的大小有关。但是,如果迁移到其他任务,图像尺度相差比较大的情况下还是会进行微调,总之就是没有一个可以拿来直接用不用调整的模型。
  • 模型训练困难:需要更多的训练数据,更大的L2正则,更多的数据增强,更多的epoch,并且对数据增强还比较敏感。

  注意:模型的参数数量和推理时间没有什么关联。如下图,虽然MobileViT的参数比Mobilenetv2的参数少,但是推理时间还是远远大于Mobilebetv2
Mobile-ViT (MobileViT)网络讲解_第1张图片

1.2.Vision Transformer

  下图是作者给出的标准的ViT模型,仔细观察就会发现,这个ViT和我们平时见到的ViT有一点不同,就是他没有cls_token。有没有cls_token不重要,cls_token只是针对分类才加上去的,下面这个网络才是最标准的视觉ViT网络。
Mobile-ViT (MobileViT)网络讲解_第2张图片
  上面展示是标准视觉ViT模型,下面我们再来看下本次介绍的重点:Mobile-ViT网路结构,如下图所示:
Mobile-ViT (MobileViT)网络讲解_第3张图片
上面的网络的核心内容就是MV2Mobile ViT block模,下面我们来介绍下这两个模块。

二.Mobile-ViT

2.1.MV2

  MV2就是mobilenetv2里面Inverted Residual Block,即下面的图所示的结构,图中MV2是当stride等于1时的MV2结构,上图中标有向下箭头的MV2结构代表stride等于2的情况,即需要进行下采样。
Mobile-ViT (MobileViT)网络讲解_第4张图片

2.2.MobileViT

  MV2来源于mobilenetv2,所以Mobile-ViT的核心还是MobileViT这个模块。我们来分析下这个结构到底是什么,为什么他能减少模型参数量,提升模型的推理速度。
Mobile-ViT (MobileViT)网络讲解_第5张图片
  从上面的模型可以看出,首先将特征图通过一个卷积层,卷积核大小为 n × n n\times n n×n,然后再通过一个卷积核大小为 1 × 1 1\times 1 1×1的卷积层进行通道调整,接着依次通过UnfoldTransformerFold结构进行全局特征建模,然后再通过一个卷积核大小为 1 × 1 1\times 1 1×1的卷积层将通道调整为原始大小,接着通过shortcut捷径分支与原始输入特征图按通道concat拼接,最后再通过一个卷积核大小为 n × n n\times n n×n的卷积层进行特征融合得到最终的输出。这里有小伙伴可能会对folodunfold感到迷惑,所以这个地方的核心又落到了global representation部分(图中中间蓝色字体部分)。

Mobile-ViT (MobileViT)网络讲解_第6张图片
  我们以单通道特征图来分析global representation这部分做了什么,假设patch划分的大小为 2 × 2 2\times 2 2×2,实际中可以根据具体要求自己设置。在Transformer中对于输入的特征图,我们一般是将他整体展平为一维向量,在输入到Transformer中,在self-attention的时候,每个图中的每个像素和其他的像素进行计算,这样计算量就是:
P 1 = W H C P_{1}=WHC P1=WHC
其中,W、H、C分别表示特征图的宽,高和通道个数。
  在Mobile-ViT中的是先对输入的特征图划分成一个个的patch,再计算self-attention的时候只对相同位置的像素计算,即图中展示的颜色相同的位置。这样就可以相对的减少计算量,这个时候的计算量为:
P 2 = W H C 4 P_{2}=\frac{WHC}{4} P2=4WHC
  为什么可以这么做呢?简单理解一张图像的每个像素点的周围的像素值都差不多,并且分辨率越高相差越小,所以这样做并不会损失太多的信息。而且Mobile-ViT在做全局表征之前已经做了一次局部表征了(图中的蓝色字体)。
  我们再来介绍下unfoldfold到底是什么意思。unfold就是将颜色相同的部分拼成一个序列输入到Transformer进行建模。最后再通过fold是拼回去。如下图所示:
Mobile-ViT (MobileViT)网络讲解_第7张图片
关于foldunfoldpytorch代码实现可以参考下这篇博文:「详解」torch.nn.Fold和torch.nn.Unfold操作。现在我们再来看下网络的整体结构是不是容易理解多了。

在这里插入图片描述
  下面我们来简单的看下patch size对模型性能的影响,patch如果划分的比较大的话是可以减少计算量的,但是划分的太大的话又会忽略更多的语义信息,影响模型的性能。我们看下轮论文里面作者做的实验,下图从左到右对语义信息的要求逐渐递增。其中配置A的patch大小为{2, 2, 2},配置B的patch大小为{8, 4, 2},这三个数字分别对应下采样倍率为8,16,32的特征图所采用的patch大小。通过比较这三幅图可以发现,在图像分类和目标检测任务中,配置A和配置B在准确率和mAP上没多大区别,配置B要更快一些。但在语义分割任务中,配置A的效果要比较好。
Mobile-ViT (MobileViT)网络讲解_第8张图片
Mobile-ViT (MobileViT)网络讲解_第9张图片

2.3.模型配置

  论文中总共给出了三组模型配置,即MobileViT-SMobileViT-XSMobileViT-XXS,三种配置是越来越轻量化。
至此,关Mobile-ViT的理论部分基本上介绍完了,欢迎各位大佬批评指正。

你可能感兴趣的:(网络模型,网络,深度学习,计算机视觉)