All Tokens Matter: Token Labeling for Training Better Vision Transformers

All Tokens Matter: Token Labeling for Training Better Vision Transformers

[论文地址]([2104.10858] All Tokens Matter: Token Labeling for Training Better Vision Transformers (arxiv.org))

[代码地址](zihangJiang/TokenLabeling: Pytorch implementation of “All Tokens Matter: Token Labeling for Training Better Vision Transformers” (github.com))

摘要

与传统ViT通过额外添加的可训练的Class token计算分类损失的目标不同。本文提出了一种新的用来提高ViT性能的训练目标token labeling。该方法通过一个注释器,为每个token提供一个监督标注,并基于这些监督标注计算每一个token的损失。对于26M可学习参数的ViT,通过token labeling在ImageNet可以实现84.4的top1准确率。随着模型规模的扩大,可以达到最高86.4%的准确率。作者还通过实验证明了token labeling可以提高预训练模型对密集预测任务(语义分割)的鲁棒性。

方法思路:

传统的ViT仅利用class token进行分类预测,虽然充分利用了模型提取的全局信息,却缺乏一定的局部信息。patch token中包含大量的局部有用信息。通过利用class token和patch token可以更好地提高模型性能。本文通过注释器为生成N×K维的分数图作为监督,密集监督所有token,K表示分类类别数,N是输出的token数目。这样,每个patch便有了一个仅与位置相关的监督,表示在某个patch上是否含有目标。本文首次证明通过密集监督有利于ViT的视觉分类任务。

与其他ViT模型的性能比较:

All Tokens Matter: Token Labeling for Training Better Vision Transformers_第1张图片

方法

模型

All Tokens Matter: Token Labeling for Training Better Vision Transformers_第2张图片

Token Labeling

标准ViT仅利用cls token用于分类损失计算,表示为:Lcls = H(Xcls, ycls)

Token Labeling通过计算每个所有输出token对损失进行密集计算,所以,该方法损失不仅涉及一个K维的向量ycls,还包括K×N的分数矩阵,N表示patch块的数目,该部分损失表示为:在这里插入图片描述

总损失:
All Tokens Matter: Token Labeling for Training Better Vision Transformers_第3张图片

其中β时是平衡两项的超参数,经实验,本文将其设置为0.5。

优点

  • 与知识蒸馏应用教师模型在线生成监督标签不同,令牌标记相对廉价
  • 通过利用分数图进行密集监督,时每个patch token的标签提供特定于位置的信息可以帮助提高精度
  • 由于密集监督的引入,有利于密集预测的下游任务如语义分割

Token Labeling with MixToken

数据增强对于ViT增强模型的性能和鲁棒性是非常必要的。然而直接在原始图像上进行CutMix可能导致最终的输出token混合了两幅图片中的信息,这使得最终的token获取的标记可能不是干净准确的如下图left红色patch所示。为此,本文提出了基于token的增强方法MixToken,如下图right所示,其实质是基于token的CutMix方法。

All Tokens Matter: Token Labeling for Training Better Vision Transformers_第4张图片

具体来说,首先将两幅图像经过patch embedding转化为两个token序列T1,T2。新的Token序列(T_hat)通过一个Mask混合,相应的token标签(Y_hat)通过同一个Mask混合,生成Mask的方法参照[CutMix]([1905.04899] CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (arxiv.org))中。其中类标记的标签(ycls_hat)如图公式三所示。

All Tokens Matter: Token Labeling for Training Better Vision Transformers_第5张图片

实验

不同版本的LV-Vit性能

All Tokens Matter: Token Labeling for Training Better Vision Transformers_第6张图片

增强方法的消融

All Tokens Matter: Token Labeling for Training Better Vision Transformers_第7张图片

Token 参与比例(左)不同的注释器(右)对性能影响

All Tokens Matter: Token Labeling for Training Better Vision Transformers_第8张图片

在三种不同的ViT上应用token labeling的性能提升

All Tokens Matter: Token Labeling for Training Better Vision Transformers_第9张图片

分类任务上的性能比较

All Tokens Matter: Token Labeling for Training Better Vision Transformers_第10张图片

语义分割任务上的性能

All Tokens Matter: Token Labeling for Training Better Vision Transformers_第11张图片

你可能感兴趣的:(论文阅读,深度学习,计算机视觉,pytorch)