作者丨科技猛兽
编辑丨3D视觉开发者社区
导读
本文提出了一种十分新颖的观点,即:输出知识蒸馏的潜力其实还没有得到完全开发。作者在本文中揭示出传统知识蒸馏方法会自然地抑制 NCKD 的作用,因此限制了知识蒸馏的潜力和灵活性。本文进一步将 TCKD 和 NCKD 进行解耦,通过独立的超参数控制二者的作用,得到的 DKD 蒸馏方法在一系列视觉任务上得到的明显的性能提升。
论文名称:Decoupled Knowledge Distillation
论文地址:
https://arxiv.org/pdf/2203.08679.pdf
(来自旷视科技,早稻田大学,清华大学)
现代知识蒸馏方法大多数注重深层的中间特征层面的知识蒸馏 (feature distillation),而对于相对而言比较原始的输出层面的知识蒸馏 (logit distillation) 的研究,因为性能不如前者而相对被忽略。本文提出了一种十分新颖的观点,即:输出知识蒸馏的潜力其实还没有得到完全开发。在本文中作者把 logit distillation 的输出分为两个部分,即:目标类别知识蒸馏 (target class knowledge distillation, TCKD) 和非目标类别知识蒸馏 (non-target class knowledge distillation, NCKD)。顾名思义,这两个名词的含义分别是指:对于模型输出中目标类别的值和非目标类别的值分别进行蒸馏。作者在这篇文章中指出:TCKD 传授给 “学生” 模型的知识是当前训练样本的难度,而 NCKD 才是知识蒸馏 work 的最主要原因。但是文章同样指出:常规的 KD 方法会 “抑制” NCKD 的作用,并且限制这两个部分的灵活度。因此本文提出将这两个部分进行解耦,分别完成对应的知识蒸馏操作。
本部分将对知识蒸馏的表达式进行重新推导。对于任意的训练样本,假设其类别属于第 t t t类,则其分类结果可以被写成: p = [ p 1 , p 2 , . . . , p t , . . . , p C ] ∈ R 1 × C {\mathbf{p}}=[p_1,p_2,...,p_t,...,p_C] \in\mathbb{R}^{1\times C} p=[p1,p2,...,pt,...,pC]∈R1×C。其中, C C C为样本的类别数, p i p_i pi为其属于第 i i i个类别的概率,通过softmax函数最终得到: p i = e x p ( z i ) ∑ j = 1 C e x p ( z j ) (1) p_i=\frac{{\rm exp}(z_i)}{\sum_{j=1}^C{\rm exp}(z_j)}\tag{1} pi=∑j=1Cexp(zj)exp(zi)(1)其中, z i z_i zi代表第 i i i个类别的output logits。
作者接下来定义属于目标类别的可能性参数 b = [ p t , p \ t ] ∈ R 1 × 2 {\mathbf{b}}=[p_t,p_{\backslash{t}}] \in\mathbb{R}^{1\times 2} b=[pt,p\t]∈R1×2,其中: p t = e x p ( z t ) ∑ j = 1 C e x p ( z j ) , p \ t = ∑ k = 1 , k ≠ t C e x p ( z k ) ∑ j = 1 C e x p ( z j ) (2) p_t=\frac{{\rm exp}(z_t)}{\sum_{j=1}^C{\rm exp}(z_j)},p_{\backslash{t}}=\frac{\sum_{k=1,k\ne t}^C{\rm exp}(z_k)}{\sum_{j=1}^C{\rm exp}(z_j)}\tag{2} pt=∑j=1Cexp(zj)exp(zt),p\t=∑j=1Cexp(zj)∑k=1,k=tCexp(zk)(2)同时,定义 p ^ = [ p ^ 1 , . . . , p ^ t − 1 , p ^ t + 1 , . . . , p ^ C ] ∈ R 1 × ( C − 1 ) {\mathbf{\hat p}} = [{\hat p}_1,...,{\hat p}_{t-1},{\hat p}_{t+1},...,{\hat p}_C]\in\mathbb{R}^{1\times (C-1)} p^=[p^1,...,p^t−1,p^t+1,...,p^C]∈R1×(C−1)为非目标类别的概率,就是在计算概率值时分母不包含第 t t t个类别: p ^ i = e x p ( z i ) ∑ j = 1 , j ≠ t C e x p ( z j ) (3) {\hat p}_i=\frac{{\rm exp}(z_i)}{\sum_{j=1,j\ne t}^C{\rm exp}(z_j)}\tag{3} p^i=∑j=1,j=tCexp(zj)exp(zi)(3)定义 T \mathcal{T} T和 S \mathcal{S} S分别为学生和教师模型,知识蒸馏使用KL-Divergence作为目标函数: K D = K L ( p T ∣ ∣ p S ) = p t T l o g ( p t T p t S ) + ∑ i = 1 , i ≠ t C p i T l o g ( p i T p i S ) \begin{align} \nonumber {\rm KD} &={\rm KL}({\mathbf{p}}^{\mathcal T}||{\mathbf{p}}^{\mathcal S})\\ &=p_t^{\mathcal T}{\rm log}(\frac{p_t^{\mathcal T}}{p_t^{\mathcal S}})+\sum ^C _{i=1,i \ne t}p_i^{\mathcal T}{\rm log}(\frac{p_i^{\mathcal T}}{p_i^{\mathcal S}}) \tag{4} \end{align} KD=KL(pT∣∣pS)=ptTlog(ptSptT)+i=1,i=t∑CpiTlog(piSpiT)(4)根据上式1和上式3,有 p ^ i = p i p \ t {\hat p} _i= \frac{p_i}{p_{\backslash{t}}} p^i=p\tpi,所以上式4可以被重写成: K D = p t T l o g ( p t T p t S ) + p \ t T ∑ i = 1 , i ≠ t C p ^ i T ( l o g ( p ^ i T p ^ i S ) + l o g ( p \ t T p \ t S ) ) = p t T l o g ( p t T p t S ) + p \ t T l o g ( p \ t T p \ t S ) ⏟ K L ( b T ∣ ∣ b S ) + p \ t T ∑ i = 1 , i ≠ t C p ^ i T l o g ( p ^ i T p ^ i S ) ⏟ K L ( p ^ T ∣ ∣ p ^ S ) \begin{align} \nonumber {\rm KD} &=p_t^{\mathcal T}{\rm log}(\frac{p_t^{\mathcal T}}{p_t^{\mathcal S}})+p_{\backslash{t}}^{\mathcal T}\sum ^C _{i=1,i \ne t}{\hat p}_i^{\mathcal T}({\rm log}(\frac{{\hat p}_i^{\mathcal T}}{{\hat p}_i^{\mathcal S}})+{\rm log}(\frac{p_{\backslash{t}}^{\mathcal T}}{p_{\backslash{t}}^{\mathcal S}}))\\ &=\underbrace {p_t^{\mathcal T}{\rm log}(\frac{p_t^{\mathcal T}}{p_t^{\mathcal S}})+p_{\backslash{t}}^{\mathcal T}{\rm log}(\frac{p_{\backslash{t}}^{\mathcal T}}{p_{\backslash{t}}^{\mathcal S}})}_{{\rm KL}({\mathbf b}^{\mathcal T}||{\mathbf b}^{\mathcal S})}+\underbrace {p_{\backslash{t}}^{\mathcal T}\sum ^C _{i=1,i \ne t}{\hat p}_i^{\mathcal T}{\rm log}(\frac{{\hat p}_i^{\mathcal T}}{{\hat p}_i^{\mathcal S}})}_{{\rm KL}({\mathbf {\hat p}}^{\mathcal T}||{\mathbf {\hat p}}^{\mathcal S})} \tag{5} \end{align} KD=ptTlog(ptSptT)+p\tTi=1,i=t∑Cp^iT(log(p^iSp^iT)+log(p\tSp\tT))=KL(bT∣∣bS) ptTlog(ptSptT)+p\tTlog(p\tSp\tT)+KL(p^T∣∣p^S) p\tTi=1,i=t∑Cp^iTlog(p^iSp^iT)(5)根据上式5可以看到,知识蒸馏损失函数可以视为是两部分 KL 散度之和:第1项是教师和学生模型关于目标类别二值概率的 KL 散度,称为目标类别知识蒸馏 (TCKD)。第2项是教师和学生模型关于非目标类别的 KL 散度,称为非目标类别知识蒸馏 (NCKD)。上式5可以被重写成: K D = K L ( b T ∣ ∣ b S ) + ( 1 − p t T ) K L ( p ^ T ∣ ∣ p ^ S ) (6) {\rm KD}={\rm KL}({\mathbf b}^{\mathcal T}||{\mathbf b}^{\mathcal S})+(1-p_t^{\mathcal T}){\rm KL}({\mathbf{\hat p}}^{\mathcal T}||{\mathbf{\hat p}}^{\mathcal S})\tag{6} KD=KL(bT∣∣bS)+(1−ptT)KL(p^T∣∣p^S)(6) K D = T C K D + ( 1 − p t T ) N C K D (7) {\rm KD}={\rm TCKD}+(1-p_t^{\mathcal T}){\rm NCKD}\tag{7} KD=TCKD+(1−ptT)NCKD(7)
如下图1所示是几种不同的模型 (ResNet8×4,ShuffleNet-V1等) 使用不同的损失函数得到的精度。直观上,TCKD 注重获得与目标类别相关的知识,因为相应的损失函数只考虑二元概率。相反,NCKD 注重获得非目标类别相关的知识。
TCKD 传递关于训练样本 “难度” 的知识:
根据上式5的第1项,TCKD 通过二元知识蒸馏任务传递了一些 “dark knowledge”,传递的知识是一个样本的目标类别的概率值的大小。比如,一个 p t T = 0.99 p_t^{\mathcal T}=0.99 ptT=0.99的样本比另外一个 p t T = 0.75 p_t^{\mathcal T}=0.75 ptT=0.75的样本更加 “容易”。当这个样本的目标类别的概率值更小时,TCKD 会更加有效。因为 CIFAR-100 训练集很容易拟合。因此,教师模型提供的知识是没有信息量的。在这一部分,作者从三个角度进行了实验来验证:训练数据越困难,TCKD 提供的帮助越大。
第一个角度是使用更强的数据增强手段。通过下图2的实验结果可以验证。
NCKD 是知识蒸馏起作用的最主要原因,但是它会被极大程度地抑制:
通过图1结果作者注意到,当仅应用 NCKD 时,实验的性能与经典 KD 相当甚至更好。这表明非目标类之间的知识至关重要,是 KD 方法 work 的主要原因。但是根据上式7可以发现,NCKD 这一项会乘以一项 ( 1 − p t T ) (1-p_t^{\mathcal T}) (1−ptT)这使得 NCKD 会被极大程度地抑制,从而无法充分发挥 KD 方法的性能。
针对这一观点作者做了个对比实验:根据 p t T p_t^{\mathcal T} ptT的大小将训练样本分为两个子集,分别使用 NCKD 蒸馏在每个子集上进行训练,实验效果如下图所示。可以发现, p t T p_t^{\mathcal T} ptT普遍较大 (0-50%) 时,NCKD 带来的性能增益就越多,证明原本 NCKD 被抑制的程度越高。
到目前为止,作者已经将经典的 KD 损失转化为两个独立部分的加权和,并进一步验证了 TCKD 的有效性,揭示了 NCKD 部分在正常使用 KD Loss 时会受到抑制。因此作者提出将这两个部分进行解耦,具体方法如下图6所示,伪代码如下图7所示。
作者在 CIFAR100 分类任务,ImageNet 分类任务,MS-COCO 目标检测任务上分别进行了实验。
CIFAR100 的实验结果如下图8所示。对于所有的教师-学生模型对,DKD 方法都获得了性能的提升。此外,DKD实现了与基于特征的蒸馏方法相当甚至更好的性能,显著改善了蒸馏性能和训练效率之间的平衡。
以 ResNet32x4 作为教师模型,ResNet8x4 作为学生模型,在 CIFAR100 数据集上的特征可视化结果如下图12所示。t-SNE 结果显示 DKD 的特征表示比 KD 更易区分,证明 DKD 有利于深层特征的可辨性。
如下图13所示为最先进的蒸馏方法的训练成本,证明了 DKD 的高训练效率。DKD 实现了模型性能和训练成本 (例如,训练时间和额外参数) 之间的最佳平衡。由于 DKD 是从经典的 KD 方法重新构造的,它只需要与 KD 几乎相同的计算复杂度,当然没有额外的参数。然而,基于特征的提取方法需要额外的提取中间层特征的训练时间,以及 GPU 的存储成本。
本文提出了一种十分新颖的观点,即:输出知识蒸馏的潜力其实还没有得到完全开发。在本文中作者把 logit distillation 的输出分为两个部分,即:目标类别知识蒸馏 (target class knowledge distillation, TCKD) 和非目标类别知识蒸馏 (non-target class knowledge distillation, NCKD)。顾名思义,这两个名词的含义分别是指:对于模型输出中目标类别的值和非目标类别的值分别进行蒸馏。作者在本文中揭示出传统知识蒸馏方法会自然地抑制 NCKD 的作用,因此限制了知识蒸馏的潜力和灵活性。本文进一步将 TCKD 和 NCKD 进行解耦,通过独立的超参数控制二者的作用,得到的 DKD 蒸馏方法在一系列视觉任务上得到的明显的性能提升。
版权声明:本文仅做学术分享,未经授权请勿二次传播,版权归原作者所有,若涉及侵权内容请联系删文。
3D视觉开发者社区是由奥比中光给所有开发者打造的分享与交流平台,旨在将3D视觉技术开放给开发者。平台为开发者提供3D视觉领域免费课程、奥比中光独家资源与专业技术支持。
点击加入3D视觉开发者社区,和开发者们一起讨论分享吧~
也可移步微信关注官方公众号 3D视觉开发者社区 ,获取更多干货知识哦!