2023 Curriculum Temperature for Knowledge Distillation

论文地址:https://arxiv.org/abs/2211.16231
代码地址:https://github.com/zhengli97/CTKD

1 研究动机与研究思路

研究动机:大多数现有的蒸馏方法忽略了温度在损失函数中的灵活作用,将其固定为超参数。一般而言,温度控制着两种分布之间的差异,确定蒸馏任务的难易程度。保持一个恒定的温度,即固定的任务难度,在渐进学习阶段通常是次优的。
研究思路:本文提出了一种简单的基于课程的技术,称为知识蒸馏课程温度( CTKD ),它是一个动态温度超参蒸馏新方法。具体来说,遵循由易到难的课程设置,随温度的变化逐渐增加蒸馏损失,以对抗的方式导致蒸馏难度的增加。

2 主要工作

本文的主要工作:

  • 本文提出在学生的训练过程中使用反向梯度对抗学习动态温度超参数,以最大化师生之间的蒸馏损失。
  • 本文引入了简单有效的课程,通过一个动态和可学习的温度参数,从易到难地组织蒸馏任务。

3 方法

3.1 知识蒸馏

传统的两段蒸馏过程通常以预先训练的繁琐的教师网络开始。然后在教师网络的监督下以soft预测或中间表示的形式训练一个紧凑的学生网络。采用带有温度超参的KL Divergence Loss散度损失最小化学生和教师模型的soft输出概率差异,从而在教师模 型和学生模型之间进行蒸馏, 公式如下:
L k d ( q t , q s , τ ) = ∑ i = 1 I τ 2 K L ( σ ( q i t / τ ) , σ ( q i s / τ ) ) L_{k d}\left(q^t, q^s, \tau\right)=\sum_{i=1}^I \tau^2 K L\left(\sigma\left(q_i^t / \tau\right), \sigma\left(q_i^s / \tau\right)\right) Lkd(qt,qs,τ)=i=1Iτ2KL(σ(qit/τ),σ(qis/τ))
其中, q t , q s q^t, q^s qt,qs分别表示教师和学生产生的logit, σ ( ⋅ ) \sigma ( · ) σ()为softmax函数.温度超参 τ \tau τ 用来衡量两个分布 q t q^t qt q s q^s qs 的平滑程度,决定了两个概率分布间的距离, τ \tau τ 越大( τ > 1 ) \tau>1) τ>1) ,就会使得概率分布越平滑(soft), τ \tau τ 越小 ( 0 < τ < 1 ) (0<\tau<1) (0<τ<1) ,越接近0,会使得概率分布越尖锐(sharp)。 τ \tau τ 的大小影响着蒸馏中学生模型学习的难度,而现有工作普遍的方式都是采用固定的温度超参,一般会设定成4。

3.2 对抗性蒸馏

针对原始蒸馏任务,以最小化任务特定损失和蒸馏损失为目标,对学生进行优化。蒸馏过程的目标可以表述如下:
min ⁡ θ s t u L ( θ s t u ) = min ⁡ θ s t u ∑ x ∈ D α 1 L t a s k ( f s ( x ; θ stu  ) , y ) + α 2 L k d ( f l ( x ; θ tea  ) , f s ( x ; θ stu  ) , τ ) . \begin{aligned} \min _{\theta_{s t u}} L\left(\theta_{s t u}\right) & =\min _{\theta_{s t u}} \sum_{x \in D} \alpha_1 L_{t a s k}\left(f^s\left(x ; \theta_{\text {stu }}\right), y\right) \\ & +\alpha_2 L_{k d}\left(f^l\left(x ; \theta_{\text {tea }}\right), f^s\left(x ; \theta_{\text {stu }}\right), \tau\right) .\end{aligned} θstuminL(θstu)=θstuminxDα1Ltask(fs(x;θstu ),y)+α2Lkd(fl(x;θtea ),fs(x;θstu ),τ).
其中 L t a s k L_{t a s k} Ltask是图像分类任务的正则交叉熵损失, f L ( ⋅ ) f^L(\cdot) fL() f s ( ⋅ ) f^s(\cdot) fs()是教师和学生的函数, α 1 \alpha_1 α1 α 2 \alpha_2 α2是平衡权重。
为了通过动态温度控制学生的学习难度,受GANs的启发,本文提出对抗学习一个动态温度模块 θ temp  \theta_{\text {temp }} θtemp ,该模块预测一个适合当前训练的温度值 τ \tau τ 。该模块在与学生相反的方向上进行优化,旨在最大化学生与教师之间的蒸馏损失。与原始蒸馏不同,学生 θ s t u \theta_{s t u} θstu和温度模块 θ temp  \theta_{\text {temp }} θtemp  以如下价值函数 L ( θ stu  , θ temp  ) L\left(\theta_{\text {stu }}, \theta_{\text {temp }}\right) L(θstu ,θtemp ) 进行两人极大极小不等式博弈:
min ⁡ θ stu  max ⁡ θ temp  L ( θ stu  , θ temp  ) = min ⁡ θ s t u max ⁡ temp  ∑ x ∈ D α 1 L task  ( f s ( x ; θ stu  ) , y ) + α 2 L k d ( f t ( x ; θ tea  ) , f s ( x ; θ s t u ) , θ temp  ) . \begin{aligned} & \min _{\theta_{\text {stu }}} \max _{\theta_{\text {temp }}} L\left(\theta_{\text {stu }}, \theta_{\text {temp }}\right) \\ & =\min _{\theta_{s t u}} \max _{\text {temp }} \sum_{x \in D} \alpha_1 L_{\text {task }}\left(f^s\left(x ; \theta_{\text {stu }}\right), y\right) \\ & +\alpha_2 L_{k d}\left(f^t\left(x ; \theta_{\text {tea }}\right), f^s\left(x ; \theta_{s t u}\right), \theta_{\text {temp }}\right) . \end{aligned} θstu minθtemp maxL(θstu ,θtemp )=θstumintemp maxxDα1Ltask (fs(x;θstu ),y)+α2Lkd(ft(x;θtea ),fs(x;θstu),θtemp ).
采用交替算法求解方程中的问题,固定一组变量,求解另一组变量:
θ ^ stu  = arg ⁡ min ⁡ θ stu  L ( θ stu  , θ ^ temp  ) θ ^ temp  = arg ⁡ max ⁡ θ temp  L ( θ ^ stu  , θ temp  ) \begin{aligned} \hat{\theta}_{\text {stu }} & =\arg \min _{\theta_{\text {stu }}} L\left(\theta_{\text {stu }}, \hat{\theta}_{\text {temp }}\right) \\ \hat{\theta}_{\text {temp }} & =\arg \max _{\theta_{\text {temp }}} L\left(\hat{\theta}_{\text {stu }}, \theta_{\text {temp }}\right)\end{aligned} θ^stu θ^temp =argθstu minL(θstu ,θ^temp )=argθtemp maxL(θ^stu ,θtemp )
通过随机梯度下降( SGD )进行优化,学生 θ s t u \theta_{s t u} θstu和温度模块 θ temp  \theta_{\text {temp }} θtemp  参数更新如下:
θ s t u ← θ s t u − μ ∂ L ∂ θ s t u θ t e m p ← θ t e m p + μ ∂ L ∂ θ temp  \begin{aligned} \theta_{s t u} & \leftarrow \theta_{s t u}-\mu \frac{\partial L}{\partial \theta_{s t u}} \\ \theta_{t e m p} & \leftarrow \theta_{t e m p}+\mu \frac{\partial L}{\partial \theta_{\text {temp }}}\end{aligned} θstuθtempθstuμθstuLθtemp+μθtemp L
通过一个非参数化的梯度反转层( Gradient Reversal Layer,GRL )来实现上述对抗过程,在softmax层和可学习温度模块之间插入GRL,知识蒸馏课程温度( CTKD )如图所示。
2023 Curriculum Temperature for Knowledge Distillation_第1张图片

图 1 :知识蒸馏课程温度 图1:知识蒸馏课程温度 1:知识蒸馏课程温度
( a )引入了一个可学习的温度模块来预测合适的蒸馏温度,使用梯度反转层来反转反向传播过程中温度模块的梯度。( b )遵循先易后难的课程设置,逐步增加参数λ,导致学生的学习难度

3.3 课程温度

保持恒定的学习难度对于一个成长中的学生来说是次优的。受课程学习的启发,本文进一步介绍了一个简单而有效的课程,它通过直接将损失 L L L λ \lambda λ大小温度来组织蒸馏任务,即 L → λ L L \rightarrow \lambda L LλL。因此, θ temp  \theta_{\text {temp }} θtemp  将被更新:
θ t e m p ← θ t e m p + μ ∂ ( λ L ) ∂ θ t e m p \theta_{t e m p} \leftarrow \theta_{t e m p}+\mu \frac{\partial(\lambda L)}{\partial \theta_{t e m p}} θtempθtemp+μθtemp(λL)
将初始λ值设置为0,使得低年级学生可以专注于学习任务而不受任何约束。通过逐步提高λ,随着蒸馏难度的增加,学生学习到更高级的知识。具体而言,遵循课程学习的基本理念,本文提出的课程满足以下两个条件:

  • 给定唯一变量 τ \tau τ ,蒸馏损失温度模块(简化为 L k d ( τ ) ) \left.L_{k d}(\tau)\right) Lkd(τ)))逐渐增大,即
    L k d ( τ n + 1 ) ≥ L k d ( τ n ) L_{k d}\left(\tau_{n+1}\right) \geq L_{k d}\left(\tau_n\right) Lkd(τn+1)Lkd(τn)
  • λ值增大,即
    λ n + 1 ≥ λ n \lambda_{n+1} \geq \lambda_n λn+1λn
    式中:n表示第n步训练。当在 E n En En epoch处训练时以如下的余弦调度逐步增加λ:
    λ n = λ min ⁡ + 1 2 ( λ max ⁡ − λ min ⁡ ) ( 1 + cos ⁡ ( ( 1 + min ⁡ ( E n , E loops  ) E loops  ) π ) \begin{aligned} \lambda_n & =\lambda_{\min } \\ & +\frac{1}{2}\left(\lambda_{\max }-\lambda_{\min }\right)\left(1+\cos \left(\left(1+\frac{\min \left(E_n, E_{\text {loops }}\right)}{E_{\text {loops }}}\right) \pi\right)\right.\end{aligned} λn=λmin+21(λmaxλmin)(1+cos((1+Eloops min(En,Eloops ))π)
    式中: λ max ⁡ \lambda_{\max } λmax λ min ⁡ \lambda_{\min } λmin λ \lambda λ的取值范围。 E l o o p s E_ {loops} Eloops是难度尺度 λ \lambda λ逐渐变化的超参数。本文默认设置 λ max ⁡ \lambda_{\max } λmax, λ min ⁡ \lambda_{\min } λmin E l o o p s E_ {loops} Eloops分别为1、0和10。该过程表明参数 λ \lambda λ在10个训练周期内从0增加到1,并一直保持1直到结束。详细的消融研究见表6和表8。

3.4可学习温度模块

可学习温度模块有两个版本,即Global - T和Instance - T。
2023 Curriculum Temperature for Knowledge Distillation_第2张图片
全局版本只包含一个可学习的参数,对所有实例预测一个值 T pred  T_{\text {pred }} Tpred ,如图2 ( a )所示。这种高效的版本不会给蒸馏过程带来额外的计算成本,因为它只涉及一个可学习的参数。
Instance-T.为了获得更好的蒸馏性能,一个全局温度对于所有实例都不够准确。我们进一步探索了实例变量Instance - T,它对所有实例单独预测一个温度值,例如,对于一批128个样本,本文预测128个对应的温度值。本文提出利用概率分布的统计信息来控制自身的平滑性。具体来说,本文引入了一个2层MLP,将两个预测作为输入,输出预测值 T pred  T_{\text {pred }} Tpred ,如图2 ( b )所示。在训练过程中,该模块会自动学习原始分布和平滑分布之间的隐含关系。为了保证温度参数的非负性并使其值保持在合适的范围内,我们用下面的公式对预测的 T pred  T_{\text {pred }} Tpred 进行标度:
τ = τ init  + τ range  ( δ ( T pred  ) ) \tau=\tau_{\text {init }}+\tau_{\text {range }}\left(\delta\left(T_{\text {pred }}\right)\right) τ=τinit +τrange (δ(Tpred ))
其中 τ init  \tau_{\text {init }} τinit  为初始值, τ range  \tau_{\text {range }} τrange  τ \tau τ的取值范围, δ ( ⋅ ) \delta(\cdot) δ()为sigmoid函数, T pred  T_{\text {pred }} Tpred 为预测值。默认设置 τ init  \tau_{\text {init }} τinit  τ range  \tau_{\text {range }} τrange 为1和20,这样可以包含所有的正常值。
与Global - T相比,Instance - T由于具有更好的表示能力,可以获得更好的精馏性能。

4 算法伪代码

2023 Curriculum Temperature for Knowledge Distillation_第3张图片

你可能感兴趣的:(文献阅读,人工智能,深度学习)