这篇文章主要是通过一些训练策略和知识蒸馏
来提升模型的训练速度和性能效果。
原文链接:Training data-efficient image transformers & distillation through attention
源码地址:https://github.com/facebookresearch/deit
写的好的文章:Transformer学习(四)—DeiT
知识蒸馏可以简单看下这篇介绍:知识蒸馏(Knowledge Distillation) 经典之作,论文笔记
虽然ViT在分类任务中有着非常高的性能,但它是使用大型基础设施预先训练了数亿张图像,才得到现在的效果,这两个条件限制了其应用。
因此作者提出了一种新的训练策略
,在不到3天的时间内仅使用一台计算机在ImageNet上训练出具有竞争力的无卷积Transformer。在没有外部数据的情况下,在ImageNet上达到了83.1%((86M参数))的最高精度。
其次,作者提出了一种基于知识蒸馏的策略
。依赖于一个蒸馏token,确保student 模型通过注意力从teacher 模型那里学习,通常老师模型是基于卷积的。学习到的Transformer与ImageNet上的最先进技术具有竞争力(85.2%)。
近来人们对利用convnet中注意力机制的体系结构越来越感兴趣,提出了混合结构,将transformer成分移植到ConvNet以解决视觉任务。在本文作者使用的是纯Transformer结构,但是在知识蒸馏策略中,使用convnet网络作为teacher网络来训练,能够继承到convnet中的归纳偏置。
ViT模型使用的是包含3亿张图像的大型私有标记图像数据集,才能达到最好的效果,同时也得到结论:在数据量不足的情况下训练时不能很好地概括。
在本文中,作者在一个8GPU节点上用两到三天的时间(53小时的预训练,以及可选的20小时的微调)训练视觉Transformer,这能够与具有相似数量的参数和效率的ConvNet相竞争。使用Imagenet作为唯一的训练集。
作者提取模型时还使用了一种基于token的蒸馏策略,文中⚗作为蒸馏标志。
概括一下有以下贡献:
不包含卷积层
,在没有外部数据的情况下,可以在ImageNet上实现与最先进技术相比的竞争结果。两个新模型变体DeiT-S和DeiT-Ti的参数更少,可以看作是ResNet-50和ResNet-18的对应物。估计标签
。这两个token通过注意力在transformer中进行交互。转移到不同的下游任务
时具有竞争力。训练策略:
在较低的分辨率下训练,并在较大分辨率下微调网络,这加快了完整训练的速度,并提高了在主流数据增强方案下的准确性。
当增加输入图像的分辨率时,保持patch大小不变,因此输入patch的数量N会发生变化。由于transformer块和class token的架构,不需要修改模型和分类器来处理更多token。而是需要调整位置嵌入,因为每个patch一个,共有N个位置嵌入。
蒸馏: 首先假设可以使用强大的图像分类器作为教师模型。它可以是convnet,也可以是分类器的混合。本节介绍:硬蒸馏与软蒸馏,以及蒸馏token。
首先借用一个知乎小皇帝的图,teacher模型是拥有更大体量和优越效果的已知模型,在蒸馏过程中,teacher模型是不进行训练的,只是作为一种指路标杆来引导图像找到teacher模型中对应我们需要的参数。实际上,我们只是利用了teacher模型映射过程中产生的别的信息。在普通的分类模型训练,我们有的信息只有图像和分类标签,如果是该类,就是1,不是就是0。但是teacher模型训练过程中经过softmax函数得到不同类别的概率,我们就是利用这个概率分布来训练student模型,除了正样本,负样本中也包含非常多的信息,但是Ground Truth并不能提供这部分信息,而teacher模型的概率分布相当于在student模型训练时增加了部分新的标签信息。更详细的内容可以看这个链接:知识蒸馏
1. 软蒸馏:
最小化教师模型的softmax和学生模型的softmax之间的Kullback-Leibler散度。假设Zt是教师模型的logits,Zs是学生模型的logits。用τ表示蒸馏温度,λ表示平衡地面真值标签y上的Kullback-Leibler发散损失(KL)和交叉熵(LCE)的系数,ψ表示softmax函数。蒸馏的目标是:y的部分普通的loss计算,后半部分是散度。
2. 硬蒸馏变体:
将Teacher模型的预测输出 y t = a r g m a x c Z t ( c ) y_t = argmax_cZ_t(c) yt=argmaxcZt(c)作为真实标签,对于给定的图像,与教师相关联的硬标签 y t y_t yt可能会根据特定的数据增加而变化。这种选择优于传统选择,同时无参数且概念更简单:教师预测 y t y_t yt与真正的标签y起着相同的作用。蒸馏目标为:
3. 蒸馏token:
在初始嵌入(patch和class token)中添加了一个新token,即蒸馏token。蒸馏token与class token类似:它通过自注意力与其他嵌入交互,并在最后一层之后由网络输出。蒸馏嵌入允许模型学习教师模型的预测输出,不仅学习到教师模型的先验知识,同时也是对class嵌入的补充。
蒸馏策略:
微调:
在更高分辨率的微调阶段使用真实标签和教师预测。使用具有相同目标分辨率的教师模型,测试阶段仅使用真正的标签。联合分类器
。在测试时,transformer生成的类或蒸馏嵌入都与线性分类器关联,并且能够推断图像标签。将这两个独立头在后期进行融合,添加两个分类器的softmax输出以进行预测。实验部分中写了不同蒸馏方法的结果比较。
这篇文章本身并没有对ViT模型进行改进,只是使用了一些训练策略,使之更容易训练,同时也提高了模型的性能。
文中核心是使用了知识蒸馏的策略,增加了模型训练过程中的负样本的预测信息,继承了teacher模型(convnet的效果优于Transformer)中的归纳偏置,实际上是对标签信息的一种补充。
最后祝各位科研顺利,身体健康,万事胜意~