论文地址:https://arxiv.org/abs/2211.16231
代码地址:https://github.com/zhengli97/CTKD
研究动机:大多数现有的蒸馏方法忽略了温度在损失函数中的灵活作用,将其固定为超参数。一般而言,温度控制着两种分布之间的差异,确定蒸馏任务的难易程度。保持一个恒定的温度,即固定的任务难度,在渐进学习阶段通常是次优的。
研究思路:本文提出了一种简单的基于课程的技术,称为知识蒸馏课程温度( CTKD ),它是一个动态温度超参蒸馏新方法。具体来说,遵循由易到难的课程设置,随温度的变化逐渐增加蒸馏损失,以对抗的方式导致蒸馏难度的增加。
本文的主要工作:
传统的两段蒸馏过程通常以预先训练的繁琐的教师网络开始。然后在教师网络的监督下以soft预测或中间表示的形式训练一个紧凑的学生网络。采用带有温度超参的KL Divergence Loss散度损失最小化学生和教师模型的soft输出概率差异,从而在教师模 型和学生模型之间进行蒸馏, 公式如下:
L k d ( q t , q s , τ ) = ∑ i = 1 I τ 2 K L ( σ ( q i t / τ ) , σ ( q i s / τ ) ) L_{k d}\left(q^t, q^s, \tau\right)=\sum_{i=1}^I \tau^2 K L\left(\sigma\left(q_i^t / \tau\right), \sigma\left(q_i^s / \tau\right)\right) Lkd(qt,qs,τ)=i=1∑Iτ2KL(σ(qit/τ),σ(qis/τ))
其中, q t , q s q^t, q^s qt,qs分别表示教师和学生产生的logit, σ ( ⋅ ) \sigma ( · ) σ(⋅)为softmax函数.温度超参 τ \tau τ 用来衡量两个分布 q t q^t qt 和 q s q^s qs 的平滑程度,决定了两个概率分布间的距离, τ \tau τ 越大( τ > 1 ) \tau>1) τ>1) ,就会使得概率分布越平滑(soft), τ \tau τ 越小 ( 0 < τ < 1 ) (0<\tau<1) (0<τ<1) ,越接近0,会使得概率分布越尖锐(sharp)。 τ \tau τ 的大小影响着蒸馏中学生模型学习的难度,而现有工作普遍的方式都是采用固定的温度超参,一般会设定成4。
针对原始蒸馏任务,以最小化任务特定损失和蒸馏损失为目标,对学生进行优化。蒸馏过程的目标可以表述如下:
min θ s t u L ( θ s t u ) = min θ s t u ∑ x ∈ D α 1 L t a s k ( f s ( x ; θ stu ) , y ) + α 2 L k d ( f l ( x ; θ tea ) , f s ( x ; θ stu ) , τ ) . \begin{aligned} \min _{\theta_{s t u}} L\left(\theta_{s t u}\right) & =\min _{\theta_{s t u}} \sum_{x \in D} \alpha_1 L_{t a s k}\left(f^s\left(x ; \theta_{\text {stu }}\right), y\right) \\ & +\alpha_2 L_{k d}\left(f^l\left(x ; \theta_{\text {tea }}\right), f^s\left(x ; \theta_{\text {stu }}\right), \tau\right) .\end{aligned} θstuminL(θstu)=θstuminx∈D∑α1Ltask(fs(x;θstu ),y)+α2Lkd(fl(x;θtea ),fs(x;θstu ),τ).
其中 L t a s k L_{t a s k} Ltask是图像分类任务的正则交叉熵损失, f L ( ⋅ ) f^L(\cdot) fL(⋅) 和 f s ( ⋅ ) f^s(\cdot) fs(⋅)是教师和学生的函数, α 1 \alpha_1 α1和 α 2 \alpha_2 α2是平衡权重。
为了通过动态温度控制学生的学习难度,受GANs的启发,本文提出对抗学习一个动态温度模块 θ temp \theta_{\text {temp }} θtemp ,该模块预测一个适合当前训练的温度值 τ \tau τ 。该模块在与学生相反的方向上进行优化,旨在最大化学生与教师之间的蒸馏损失。与原始蒸馏不同,学生 θ s t u \theta_{s t u} θstu和温度模块 θ temp \theta_{\text {temp }} θtemp 以如下价值函数 L ( θ stu , θ temp ) L\left(\theta_{\text {stu }}, \theta_{\text {temp }}\right) L(θstu ,θtemp ) 进行两人极大极小不等式博弈:
min θ stu max θ temp L ( θ stu , θ temp ) = min θ s t u max temp ∑ x ∈ D α 1 L task ( f s ( x ; θ stu ) , y ) + α 2 L k d ( f t ( x ; θ tea ) , f s ( x ; θ s t u ) , θ temp ) . \begin{aligned} & \min _{\theta_{\text {stu }}} \max _{\theta_{\text {temp }}} L\left(\theta_{\text {stu }}, \theta_{\text {temp }}\right) \\ & =\min _{\theta_{s t u}} \max _{\text {temp }} \sum_{x \in D} \alpha_1 L_{\text {task }}\left(f^s\left(x ; \theta_{\text {stu }}\right), y\right) \\ & +\alpha_2 L_{k d}\left(f^t\left(x ; \theta_{\text {tea }}\right), f^s\left(x ; \theta_{s t u}\right), \theta_{\text {temp }}\right) . \end{aligned} θstu minθtemp maxL(θstu ,θtemp )=θstumintemp maxx∈D∑α1Ltask (fs(x;θstu ),y)+α2Lkd(ft(x;θtea ),fs(x;θstu),θtemp ).
采用交替算法求解方程中的问题,固定一组变量,求解另一组变量:
θ ^ stu = arg min θ stu L ( θ stu , θ ^ temp ) θ ^ temp = arg max θ temp L ( θ ^ stu , θ temp ) \begin{aligned} \hat{\theta}_{\text {stu }} & =\arg \min _{\theta_{\text {stu }}} L\left(\theta_{\text {stu }}, \hat{\theta}_{\text {temp }}\right) \\ \hat{\theta}_{\text {temp }} & =\arg \max _{\theta_{\text {temp }}} L\left(\hat{\theta}_{\text {stu }}, \theta_{\text {temp }}\right)\end{aligned} θ^stu θ^temp =argθstu minL(θstu ,θ^temp )=argθtemp maxL(θ^stu ,θtemp )
通过随机梯度下降( SGD )进行优化,学生 θ s t u \theta_{s t u} θstu和温度模块 θ temp \theta_{\text {temp }} θtemp 参数更新如下:
θ s t u ← θ s t u − μ ∂ L ∂ θ s t u θ t e m p ← θ t e m p + μ ∂ L ∂ θ temp \begin{aligned} \theta_{s t u} & \leftarrow \theta_{s t u}-\mu \frac{\partial L}{\partial \theta_{s t u}} \\ \theta_{t e m p} & \leftarrow \theta_{t e m p}+\mu \frac{\partial L}{\partial \theta_{\text {temp }}}\end{aligned} θstuθtemp←θstu−μ∂θstu∂L←θtemp+μ∂θtemp ∂L
通过一个非参数化的梯度反转层( Gradient Reversal Layer,GRL )来实现上述对抗过程,在softmax层和可学习温度模块之间插入GRL,知识蒸馏课程温度( CTKD )如图所示。
图 1 :知识蒸馏课程温度 图1:知识蒸馏课程温度 图1:知识蒸馏课程温度
( a )引入了一个可学习的温度模块来预测合适的蒸馏温度,使用梯度反转层来反转反向传播过程中温度模块的梯度。( b )遵循先易后难的课程设置,逐步增加参数λ,导致学生的学习难度
保持恒定的学习难度对于一个成长中的学生来说是次优的。受课程学习的启发,本文进一步介绍了一个简单而有效的课程,它通过直接将损失 L L L按 λ \lambda λ大小温度来组织蒸馏任务,即 L → λ L L \rightarrow \lambda L L→λL。因此, θ temp \theta_{\text {temp }} θtemp 将被更新:
θ t e m p ← θ t e m p + μ ∂ ( λ L ) ∂ θ t e m p \theta_{t e m p} \leftarrow \theta_{t e m p}+\mu \frac{\partial(\lambda L)}{\partial \theta_{t e m p}} θtemp←θtemp+μ∂θtemp∂(λL)
将初始λ值设置为0,使得低年级学生可以专注于学习任务而不受任何约束。通过逐步提高λ,随着蒸馏难度的增加,学生学习到更高级的知识。具体而言,遵循课程学习的基本理念,本文提出的课程满足以下两个条件:
可学习温度模块有两个版本,即Global - T和Instance - T。
全局版本只包含一个可学习的参数,对所有实例预测一个值 T pred T_{\text {pred }} Tpred ,如图2 ( a )所示。这种高效的版本不会给蒸馏过程带来额外的计算成本,因为它只涉及一个可学习的参数。
Instance-T.为了获得更好的蒸馏性能,一个全局温度对于所有实例都不够准确。我们进一步探索了实例变量Instance - T,它对所有实例单独预测一个温度值,例如,对于一批128个样本,本文预测128个对应的温度值。本文提出利用概率分布的统计信息来控制自身的平滑性。具体来说,本文引入了一个2层MLP,将两个预测作为输入,输出预测值 T pred T_{\text {pred }} Tpred ,如图2 ( b )所示。在训练过程中,该模块会自动学习原始分布和平滑分布之间的隐含关系。为了保证温度参数的非负性并使其值保持在合适的范围内,我们用下面的公式对预测的 T pred T_{\text {pred }} Tpred 进行标度:
τ = τ init + τ range ( δ ( T pred ) ) \tau=\tau_{\text {init }}+\tau_{\text {range }}\left(\delta\left(T_{\text {pred }}\right)\right) τ=τinit +τrange (δ(Tpred ))
其中 τ init \tau_{\text {init }} τinit 为初始值, τ range \tau_{\text {range }} τrange 为 τ \tau τ的取值范围, δ ( ⋅ ) \delta(\cdot) δ(⋅)为sigmoid函数, T pred T_{\text {pred }} Tpred 为预测值。默认设置 τ init \tau_{\text {init }} τinit 和 τ range \tau_{\text {range }} τrange 为1和20,这样可以包含所有的正常值。
与Global - T相比,Instance - T由于具有更好的表示能力,可以获得更好的精馏性能。