前言
目前提高机器学习算法性能的方法几乎都用多模型ensamble,在计算上非常昂贵且难以部署,尤其是大型神经网络,比如bert。
知识蒸馏就是希望用小型模型得到跟大型复杂模型一样的性能。人们普遍认为,用于训练的目标函数应尽可能地反映任务的真实目标。在训练过程中,往往以最优化训练集的准确率作为训练目标,但真实目标其实应该是最优化模型的泛化能力。显然如果能直接以提升模型的泛化能力为目标进行训练是最好的,但这需要正确的关于泛化能力的信息,而这些信息通常不可用。
那么怎么能获得可以利用的泛化能力的信息呢?
在一般的分类训练任务中,我们以softmax层输出各个类别的概率,然后以跟one-hot lables的交叉熵作为loss function,这个loss function丢失了在其他类别上的概率,只把正确label上的概率值考虑进来。但其实在错误label上的概率分布也是具有价值的。比如狗的样本识别为老虎上的错误概率会比识别为蚂蚁的错误概率大,因为更像,而不像。这在一定程度上反映了该分类器的泛化机制,反映了它“脑袋里的知识”。
知识蒸馏
如果我们使用由大型模型产生的所有类概率作为训练小模型的目标,是否就是直接以“泛化能力”作为目标函数呢?研究证明,这样的方法确实可以让小模型得到不输大模型的性能,而且有时甚至青出于蓝胜于蓝。这种把大模型的“知识”迁移到小模型的方式,我们称之为“蒸馏”(浓缩就是精华)。有人用单层BiLSTM对bert进行蒸馏,效果不输ELMo。(详细可看论文 Distilling Task-Specific Knowledge from BERT into Simple Neural Networks )
这里先定义两个概念。“硬目标(softmax)”:正确标签的交叉熵。“软目标”(soft_softmax):大模型产生的类概率的交叉熵。
soft_softmax公式如下:
可以看到,它跟softmax比起来就是在指数项里多了一个“T”,这个T称为蒸馏温度。为什么要加T呢?假如我们分三类,然后网络最后的输出是[1.0 2.0 3.0],我们可以很容易的计算出,传统的softmax(即T=1)对此进行处理后得到的概率为[0.09 0.24 0.67],而当T=4的时候,得到的概率则为[0.25 0.33 0.42]。可以看出,当T变大的时候输出的概率分布变得平缓了,这就蒸馏温度的作用。这时候得到的概率分布我们称之为“soft target label”。我们在训练小模型的时候需要用到“soft target label”。
在训练小模型时,目标函数为:
其中
为soft target lable,这里T要跟蒸馏复杂模型时的T大小一致,也就是保持同样的蒸馏温度,避免改变“知识”分布。注意:小模型在做预测时蒸馏温度要还原为1,也就是用原始概率分布做预测,因为再预测时希望正确标签与错误标签的概率差距尽量大,与蒸馏时的希望平缓区别开来。
实际上可以这么理解,知识蒸馏是在本来的目标函数上加上了正则项,正则项可以提高模型的泛化能力,把软目标当作正则项就是让小模型的泛化能力尽量接近复杂模型的泛化能力。软目标具有高熵值时,它们为每个训练案例提供比硬目标更多的信息,并且在训练案例之间梯度的变化更小,因此小模型通常可以在比原始繁琐模型少得多的数据上训练并可以使用更高的学习率加快训练过程。
总结一下:
知识蒸馏就是
1.从复杂模型中得到“soft target label”。
2.在训练小模型时同时训练硬目标和软目标。