知识蒸馏经典论文阅读

这篇Hinton大佬的 Distilling the Knowledge in a Neural Network,是知识蒸馏领域的开山之作,第一次引入了知识蒸馏的概念。

整体的论文研究动机如下:

  • 模型在工业落地对实时性和计算资源有要求高,尤其是像移动终端,需要在尽可能小的部署代价下快速得到准确预测结果
  • 为了提升模型准确率,往往采用集成学习的思想,用一组模型共同决策,而这更增加了模型的体量,所以Hinton的目标是将一组模型中的知识提取为简单的模型
  • 训练模型会将概率分配给所有的错误答案,就算他们概率很小,但是也比其他的答案概率大很多
  • 其他预测类别的概率太小,基本接近于0,交叉熵很小,在交叉熵验证时产生很小的影响

举个例子,我们训练了一个CNN的网络,给定一张宝马车的图片,CNN网络softmax输出如下:

label 概率
宝马车 0.90
垃圾车 0.09
胡萝卜 0.01

我们可以看到,CNN模型经过训练后,很“自信”的认为图片是宝马车,而图片被认成垃圾车的概率虽然很小,但仍是胡萝卜的9倍。
此时的模型,我们是否可以认为,它不仅学到了正确识别宝马车,还学会区分其他不同种类,那在压缩模型的时候,如果小模型能够学到这种潜在的知识,而不是只会背答案,模型的泛化能力将会提升很多,这便是整个文章的核心。

论文要点

神经网络通常通过使用转换logit的“ softmax”输出层来产生label概率, Teacher模型的输出,经过softmax层,被指数e拉大了各个label的距离,最终输出结果类似于one-hot向量,不利于student模型的学习。

所以Hinton在基础上引入蒸馏温度 T T T 的概念,通过计算 z i z_{i} zi, 将每个label转换为概率 q i q_{i} qi.
q i = e x p ( z i / T ) ∑ j e x p ( z j / T ) q_{i} = \frac{exp(z_{i}/T)}{\sum_{j}^{}exp(z_{j}/T)} qi=jexp(zj/T)exp(zi/T)

蒸馏温度T一般被设置成1,即正常的softmax,T越大输出越软、分布越缓和;
T越小越容易放大错误分类的概率,引入不必要的噪声。

∂ C ∂ z i = 1 T ( q i − p i ) = 1 T ( e z i / T ∑ j e z j / T − e v i / T ∑ j e v j / T ) \frac{\partial C}{\partial z_{i}} = \frac{1}{T}(q_{i} - p_{i}) = \frac{1}{T}(\frac{e^{z_{i}/T}}{\sum_{j}^{}e^{z_{j}/T}} - \frac{e^{v_{i}/T}}{\sum_{j}^{}e^{v_{j}/T}}) ziC=T1(qipi)=T1(jezj/Tezi/Tjevj/Tevi/T)
利用一阶泰勒展开,那么 e z i / T ≈ 1 + z i / T e^{z_{i}/T} \approx 1 + z_{i}/T ezi/T1+zi/T,那么上述公式简化为:
∂ C ∂ z i ≈ 1 T ( 1 + z i / T N + ∑ j z j / T − 1 + v i / T N + ∑ j v j / T ) \frac{\partial C}{\partial z_{i}} \approx \frac{1}{T}(\frac{1 + z_{i}/T}{N + \sum_{j}^{}{z_{j}/T}} - \frac{1 + v_{i}/T}{N + \sum_{j}^{}{v_{j}/T}}) ziCT1(N+jzj/T1+zi/TN+jvj/T1+vi/T)
如果我们假定logis在每个迁移样例中是0均值的,那么 ∑ j z j = ∑ j v j = 0 \sum_{j}^{}{z_{j}} = \sum_{j}^{}{v_{j}} = 0 jzj=jvj=0,那么上述公式简化为:
∂ C ∂ z i ≈ 1 N T 2 ( z i − v i ) \frac{\partial C}{\partial z_{i}} \approx \frac{1}{NT^2}(z_{i} - v_{i}) ziCNT21(zivi)

蒸馏以后,对比之前的softmax,梯度相当于乘了 1 T 2 \frac{1}{T^2} T21,因此 L s o f t L_{soft} Lsoft需要乘以 T 2 T^2 T2 L h a r d L_{hard} Lhard一个数量级上。

论文框架

通过定义一定的loss规则,将teacher模型(Big Model)中的知识转移到student模型(Small Model)上,论文算法的整体框架图如下:
知识蒸馏经典论文阅读_第1张图片

L o s s = λ ∗ L s o f t + ( 1 − λ ) ∗ L h a r d Loss = \lambda * L_{soft} + (1-\lambda)*L_{hard} Loss=λLsoft+(1λ)Lhard

λ \lambda λ越大,说明模型越依赖teacher网络的知识;
L s o f t L_{soft} Lsoft衡量Student模型从Teacher模型中学到知识的能力;
L h a r d L_{hard} Lhard衡量Student模型从真实标签中学到做对答案的能力。

我们之前说到其他预测类别的概率太小,交叉熵计算结果很小,提供很少的信息。于是作者提出来soft target.

soft target

利用各个预测分布的算术均值或几何均值作为软目标,这样soft target有更高的熵,训练时能提供更多的信息;在梯度下降中提供更低的偏差,方便使用更少的数据、更大的学习率。

soft target为 teacher模型最后一层隐藏层 H t H_{t} Ht、student模型最后一层隐藏层 H s H_{s} Hs,经过蒸馏温度 T T T softmax后的loss.

hard target

hard target为 Student模型最后一层隐藏层 H s H_{s} Hs 经过softmax,同真实标签 Y Y Y的loss。

实验

数据集:MNIST

teacher Model:训练一个有两层具有1200个单元的隐藏层的大型网络(使用dropout和weight-constraints作为正则)

student Model:具有两层800个单元隐藏层没有正则的网络

  • 如果小型网络正则化通过增加一个额外任务,匹配由大型网络生成的软目标实现,精度会上升;
  • 软目标可以将大量的知识转移到提取的模型中,包括从翻译后的训练数据中学习到的如何泛化的知识,即使转换集没有任何转换

总结

这篇文章算是知识蒸馏的入门,核心就是通过蒸馏温度 T T T将大模型对所有label的预测概率尽可能的迁移到小模型中,可以配合哈工大和讯飞实验室的知识蒸馏框架 TextBrewer食用。

你可能感兴趣的:(知识蒸馏经典论文阅读)