Transformer论文解读三(distillation token)

最近Transformer在CV领域很火,Transformer是2017年Google发表的Attention Is All You Need中主要是针对自然语言处理领域提出的,后被拓展到各个领域。本系列文章介绍Transformer及其在各种领域引申出的应用。

本文介绍的Training data-efficient image transformers & distillation through attention将蒸馏应用于Transformer,在没有外部数据预选训练数据的情况下,可以在ImageNet上取得与目前的技术水平相当的结果。

Vision Transformer中提出的class token是一个可训练的向量,在第一层之前附加到patch token,该token经过Transformer层,然后用一个线性层进行投影,以预测类。因此,Transformer处理D维的(N+1)patch token时,仅使用class token来预测输出。这种体系结构迫使自注意在patch token和class token之间传播信息:在训练时,监督信号仅来自token embeding,而patch token是模型唯一的变量输入。

论文中给出的蒸馏过程如下:
Transformer论文解读三(distillation token)_第1张图片
其中包含了一个新的distillation token。它通过自我注意层与class token交互。这个distillation token与class token的使用方式类似,只是在网络的输出上,它的目标是再现教师预测的(硬)标签,而不是真标签。输入到Transformer的类和distillation token都是通过反向传播来学习的。

1. hard distillation versus soft distillation

soft distillation:软蒸馏使教师模型的softmax和学生模型的softmax之间的Kullback-Leiblersandu最小化。损失函数为:
在这里插入图片描述
其中 Z t Z_t Zt为老师模型的对数, Z s Z_s Zs为学生模型的对数。 T T T表示蒸馏温度, λ \lambda λ表示平衡Kullback-Leibler散度损失(KL)和交叉熵(LCE)的系数, ψ \psi ψ表示softmax函数。

hard distillation:论文中引入了一种蒸馏的变体,我们把老师的困难分类标签作为真实标签。与这个硬标签蒸馏相关的目标是:

在这里插入图片描述

其中 y t = a r g m a x c ( Z t ( c ) ) y_t= argmax_c(Z_t(c)) yt=argmaxc(Zt(c))是老师的困难分类标签。

对于给定的图像,与教师相关的硬标签可能会根据具体的数据增强而改变。我们将看到这种选择比传统的选择更好,同时没有参数,概念上更简单:教师预测与真实标签 y y y扮演相同的角色。

注意,硬标签也可以通过标签平滑转换为软标签,其中真标签的概率为 1 − ε 1-\varepsilon 1ε,其余的 ε \varepsilon ε在其余的类中共享。在所有使用真标签的实验中,论文中将这个参数设置为 ε = 0.1 \varepsilon=0.1 ε=0.1

2. classical distillation versus the distillation token

作者向初始embedding(patch和class token)添加一个新的token,即distillation token。distillation token与class token的使用类似:它通过自注意与其他embedding交互,并由最后一层之后的网络输出。蒸馏embedding允许我们的模型从老师的输出中学习,就像在常规蒸馏中一样,同时保持对一般embedding的补充。

3. joint classifiers

作者使用联合分类器对上述方法进行分类。在测试时,由Transformer产生的类或蒸馏embedding都与线性分类器相关联,并能够推断图像标签。参考方法是这两个分离头的后期融合,作者添加两个分类器的softmax输出来进行预测。

你可能感兴趣的:(计算机视觉,蒸馏,Transformer,transformer,深度学习,机器学习)