论文:Training data-efficient image transformers & distillation through attention
代码:https://github.com/facebookresearch/deit
性能对比:top-1准确率 vs. 网络吞吐量(仅在ImageNet1k上训练)——使用transformer专用蒸馏方法训练的模型最优。
回顾原始ViT的原理:
主要介绍了软蒸馏、硬蒸馏两种损失函数,和Distillation token结构。
L g l o b a l = ( 1 − λ ) L C E ( ψ ( Z s ) , y ) + λ τ 2 K L ( ψ ( Z s / τ ) , ψ ( Z t / τ ) ) \mathcal{L}_{global}=(1-\lambda)\mathcal{L}_{CE}(\psi(Z_s),y)+\lambda\tau^2\mathrm{KL}(\psi(Z_s/\tau),\psi(Z_t/\tau)) Lglobal=(1−λ)LCE(ψ(Zs),y)+λτ2KL(ψ(Zs/τ),ψ(Zt/τ))
λ , τ \lambda,\tau λ,τ是超参数, y y y是ground truth, ψ \psi ψ是softmax函数, Z s , Z t Z_s,Z_t Zs,Zt分别是学生模型、教师模型的输出, L C E \mathcal{L}_{CE} LCE是交叉熵损失, K L \mathrm{KL} KL是KL散度。
L g l o b a l h a r d D i s t i l l = 1 2 L C E ( ψ ( Z s ) , y ) + 1 2 τ 2 L C E ( ψ ( Z s ) , y t ) \mathcal{L}_{global}^{hardDistill}=\frac12\mathcal{L}_{CE}(\psi(Z_s),y)+\frac12\tau^2\mathcal{L}_{CE}(\psi(Z_s),y_t) LglobalhardDistill=21LCE(ψ(Zs),y)+21τ2LCE(ψ(Zs),yt)
y t y_t yt是教师模型的预测结果。
如图,在patches中加入与class token类似的distillation token,两者的通过网络时的计算方式相同,区别在于class token目标是重现ground truth标签,而distillation token目标是重现教师模型的预测。
输出时的distillation token与class token余弦相似度为0.93,表明两者的目标相似但不相同。
当用一个class token替换distillation token时,两个class token输出的余弦相似度为0.999,网络性能与一个class token相近,而加入distillation token的网络性能明显提升。这表明distillation token的设定是有效的。
对蒸馏的结构进行微调,需要将教师网络的目标分辨率提升。
分类结果由class和distillation输出的softmax之和决定。
定义了与ViT-B参数相同的DeiT-B模型,和更小的DeiT-S、DeiT-Ti模型,超参数如下:
实验发现RegNetY-16GF是效果最好的教师模型,后续实验默认选择。卷积网络教师优于transformer教师,可能因为继承了卷积网络的bias。
硬蒸馏优于软蒸馏,class和distillation token同时使用优于单独使用一个。
下表为不同分类器中分类结果不同的比例。结果表明,使用distillation embedding的分类器结果与卷积网络更相似,使用class embedding的分类器结果与无蒸馏的DeiT更相似,两者结合的分类器结果介于两者之间。
300epochs后使用distillation token的网络已经占优,且性能仍未饱和,继续训练可以提升准确率。
使用timm库的实现,对比DeiT、ViT和卷积网络的准确率和效率(吞吐量)。
在其他数据集的预测表现:
消融实验:adamw优化器优于SGD,各种数据增强方法几乎都有效(除了dropout)。
DeiT对参数初始化相对敏感,使用截断的正态分布进行初始化。DeiT对优化超参数很敏感。
ViT-B和DeiT-B的训练超参数如下表: