深度学习 论文解读——知识蒸馏

深度学习 论文解读——知识蒸馏_第1张图片

1.简介

在大规模的的机器学习任务中,我们经常需要使用集成学习来训练一些相似的模型,而这会消耗大量的计算资源。并且在部署为实际应用时,也会由于模型/网络过于复杂而难以部署。针对此问题,Hinton提出了一种压缩网络的有效方法——知识蒸馏(Knowledge Distillation)[1]。并且讨论了知识蒸馏的其他优势。

2.蒸馏方法

当我们已经有一个训练好的,较为笨重的模型(cumbersome model,通常也称为教师模型)时,我们可以采用一种新的训练方式——蒸馏,来将原模型中已经学习到的知识转移到新的,轻量级的模型(通常也称为学生模型)中。对于分类任务而言,模型本质上是在将一个输入的矩阵映射为一个输出矩阵,输出矩阵中的最大值对应的就是最终预测的分类。除了最大值给我们传递的预测信息之外,模型对于其他类的预测值同样也包含了很关键的信息。例如,当输入为一辆汽车的图片时,卡车类在预测结果中的值将会远大于萝卜类的预测值。利用好这一点,我们就可以实现知识的转移,即将笨重模型中包含的知识转移到轻量级的模型中。

我们将除目标类外,别的类的预测值成为soft targets(当教师模型为一个集成模型时,soft targets可以取为多个模型的算数平均,或几何平均值)。由于soft targets中包含了大量的信息,所以在训练学生模型时,使用少量的数据就可以完成训练。

在学生模型的训练中,一个寻常的想法是最小化学生模型的预测logits与教师模型的logits间的均方误差。对此,文章中引入了温度的概念。

在常用的softmax函数中,我们没有使用到这个温度T(换言之,T=1)。这里引入的广义的softmax函数,能够生成与原先不同的输出值。容易看出,当T值很高时,预测值的分布将更加靠近,因此也包含了更多的信息。

也就是说,我们先将T升高来得到一个更"软"的分布(softer probability distribution),但是在预测阶段,重新将温度值调低来获取模型中的知识,这也是称这个方法为蒸馏的原因。

注意到我们改变温度值时,实质上改变了数据的分布,因此我们最后还需要乘上T方。

3.实验验证

在MNIST数据集上使用蒸馏方法时,测试错误从146个降低到了74个(温度值为20)。这说明soft targets确实传递了更多的信息。

当训练学生模型时,如果除去训练集中所有的’3’,模型总共的预测错误是206,其中133个是数字3,而数字3的总数为1010,这表明我们用一个有缺失类的训练集,也能很好的学到别的类的信息。这一点也在训练集中只包含8和9的实验中得到了很好的证明。

为了证明是否能使用更少的训练集来取得相近的预测效果,文中采用了3%的数据来进行对比:

效果对比

可以看到,使用soft targets训练得到的学生模型表现远好于baseline的训练效果,其效果仅比使用大量数据进行训练、使用early stopping的baseline差了1.9%。值得一提的是,使用soft targets进行训练,最终的模型是近乎收敛到最优值而不需要采用early stopping。

4.讨论

对于大型神经网络,完整的进行训练也是不可行的,但是已经证明可以使用多个专家模型进行加速。每个专家网都能够区分一部分高度混淆的的类。但目前还不明确如何将专家模型的知识全部转移到单个的大模型中。

References:

[1]. Hinton, G., O. Vinyals and J. Dean, Distilling the Knowledge in a Neural Network. 2015.

你可能感兴趣的:(数据科学,python,算法)