知识蒸馏(Knowledge Distillation)

知识蒸馏(Knowledge Distillation)_第1张图片

知识蒸馏(Knowledge Distillation)_第2张图片 论文:[1503.02531] Distilling the Knowledge in a Neural Network (arxiv.org)

知识蒸馏是一种模型压缩方法,是一种基于“教师-学生网络思想”的训练方式,由于其简单,有效,并且已经在工业界被广泛应用。

知识蒸馏使用的是Teacher—Student模型,其中teacher是“知识”的输出者,student是“知识”的接受者。知识蒸馏的过程分为2个阶段:

        ①原始模型训练: 训练"Teacher模型", 简称为Net-T,它的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。我们对"Teacher模型"不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是,对于输入X, 其都能输出Y,其中Y经过softmax的映射,输出值对应相应类别的概率值。

        ②精简模型训练: 训练"Student模型", 简称为Net-S,它是参数量较小、模型结构相对简单的单模型。同样的,对于输入X,其都能输出Y,Y经过softmax映射后同样能输出对应相应类别的概率值。在本论文中,作者将问题限定在分类问题下,或者其他本质上属于分类问题的问题,该类问题的共同点是模型最后会有一个softmax层,其输出值对应了相应类别的概率值。

现实中,由于我们不可能收集到某问题的所有数据来作为训练数据,并且新数据总是在源源不断的产生,因此我们只能退而求其次,训练目标变成在已有的训练数据集上建模输入和输出之间的关系。由于训练数据集是对真实数据分布情况的采样,训练数据集上的最优解往往会多少偏离真正的最优解。

而在知识蒸馏时,由于我们已经有了一个泛化能力较强的Net-T,我们在利用Net-T来蒸馏训练Net-S时,可以直接让Net-S去学习Net-T的泛化能力。一个很直白且高效的迁移泛化能力的方法就是使用softmax层输出的类别的概率来作为“soft target”。

        ①传统training过程(hard targets): 对ground truth求极大似然

        ②KD的training过程(soft targets): 用large model的class probabilities作为soft targets

知识蒸馏(Knowledge Distillation)_第3张图片

 例子:

在MNIST手写数字识别任务中

假设某个输入的“2”更加形似"3",softmax的输出值中"3"对应的概率为0.1,而其他负标签对应的值都很小,而另一个"2"更加形似"7","7"对应的概率为0.1。这两个"2"对应的hard target的值是相同的,但是它们的soft target却是不同的,由此我们可见soft target蕴含着比hard target多的信息。并且soft target分布的熵相对高时,其soft target蕴含的知识就更丰富。

知识蒸馏(Knowledge Distillation)_第4张图片

 两个”2“的hard target相同而soft target不同。

这就解释了为什么通过蒸馏的方法训练出的Net-S相比使用完全相同的模型结构和训练数据只使用hard target的训练方法得到的模型,拥有更好的泛化能力。

温度T

把其他类别的可能性放大,把他们的相对大小充分暴露出来,让学生网络更加强烈地知道这些非类别的信息。当T=1时,与之前没有变化;当T越大,曲线的波峰就会越来越平滑

知识蒸馏(Knowledge Distillation)_第5张图片

 知识蒸馏(Knowledge Distillation)_第6张图片

 知识蒸馏的过程:

第一步:有一个已经训练好的Teacher model,把很多数据喂给Teacher model,再把数据喂给(未训练/半成品)Student model,两个都是在T=t时经过Softmax,然后计算这两个的损失函数值,让它们两个越接近越好,学生在模拟老师的预测结果。

第二步:Student model在T=1情况下经过softmax操作,把预测结果hard prediction和真实数据的结果hard label进行求损失值,希望它们两个越接近越好。

总结:Student model(T=t)与Teacher model(T=t)的预测结果越来越接近;Student model(T=1)的预测结果与数据结果(标准答案)越来越接近。

Loss = k1*distillation Loss+k2*student Loss。(加权求和)

知识蒸馏(Knowledge Distillation)_第7张图片

知识蒸馏(Knowledge Distillation)_第8张图片

 ​​​​知识蒸馏(Knowledge Distillation)_第9张图片

在使用Student model时只需要输入数据就行,不需要T,因为模型的参数已经训练完成了,最后只需要经过基础softmax操作得到最终结果。知识蒸馏(Knowledge Distillation)_第10张图片

 实验结果:

使用MNIST数据集训练Teacher model,把MNIST数据集中去除”3“相关的所有数据集来训练Student model,实验结果证明,经过知识蒸馏后,没有学习过”3“的Student model可以识别出”3“。

Soft targets可以仅仅使用3%的训练集来训练并达到近似Teacher model的效果。

知识蒸馏的应用场景:

①模型压缩

②优化训练,防止过拟合

③无限大、无监督数据集的数据挖掘

④少样本、零样本学习

你可能感兴趣的:(深度学习,人工智能)