深度学习论文笔记(rethinking knowledge distillation)——On the Efficacy of Knowledge Distillation

文章目录

  • 前言
  • 疑问:高准确率的大模型一定就是好teacher吗?
    • 问题二
    • 问题一
  • 可能的解决方案

这篇文章非常有意思,本文文字部分较多,主要记录了个人对于文章的一些思考

前言

《On the Efficacy of Knowledge Distillation》于2019年发表在ICCV上。通过实验,作者发现了一个“怪相”,准确率越高的模型并不一定就是好的teacher模型,对于同一个student模型而言,teacher模型越大,teacher模型的准确率越高,知识蒸馏得到的student模型性能却越差。作者认为是student模型与teacher模型的容量相差太大,导致student模型无法模拟teacher模型,从而出现上述怪相。为了解决上述“怪相”,作者提出在训练teacher模型时进行early stop,即teacher模型训练一定轮数后停止训练,此时teacher模型可能还未较好拟合训练集。

按我的理解,知识蒸馏的实质是想让student模型模拟teacher模型的输出,从而在student模型的参数空间中找到一个解,这个解是teacher模型找到的解的近似解,当student模型与teacher模型容量差距太大时,将导致student模型无法找到近似解,但是由此又引发一个问题,该如何防止student模型与teacher模型的容量差距太大呢?论文并没有解释。作者从另外一个角度提出了解决方案,通过在训练时early stop teacher模型,让teacher模型找到的解尽可能简单,从而让student模型尽可能找到近似解。

2015年,hinton在论文《Distilling the Knowledge in a Neural Network》中提出了知识蒸馏,本文中的“知识蒸馏”均指该方法。


疑问:高准确率的大模型一定就是好teacher吗?

深度学习论文笔记(rethinking knowledge distillation)——On the Efficacy of Knowledge Distillation_第1张图片
上图展示了当student模型为resnet18,teacher模型为resnet18、34、50时,在IMageNet数据上,利用知识蒸馏得到的student模型的准确率,第一行表示不使用知识蒸馏训练的resnet18的准确率,从上图可以看出两个问题:

  • teacher模型越大、准确率越高,student模型的准确率却越低。
  • 模型结构相同,使用了知识蒸馏的resnet18准确率不如未使用知识蒸馏的resnet18

问题二

第二个问题为本人提出,问题二驳斥了目前网上对于知识蒸馏为什么有效的一个解释,即soft label相比于hard label,可以提供更多类与类之间相似性的信息,这类信息将有助于student模型区分类。但是从上图数据来看,当teacher模型为resnet18时,给出的soft label也可以反映类与类之间的相似性,但是student模型的准确率却并没有更高,因此个人不是很认同这个观点。

针对于问题二,个人的理解是——若student模型与teacher模型的结构不同,teacher模型性能优于student模型,此时知识蒸馏可能可以让student模型在参数空间中找到一个与teacher模型解近似的解,这个近似解通常不如teacher模型的解,但可能可以让student模型与teacher模型性能近似,从而提高student模型的性能。若student模型与teacher模型的结构相同,teacher模型的性能与student模型性能基本一致,此时student找到的近似解并不一定就能提高student模型的准确率。


问题一

针对问题一,作者提出了三种假设

  • teacher模型越大,给出的soft label越接近于hard label,给出的信息越来越近似于hard label
  • student可以模拟teacher,但这不能导致student的泛化性能提升,即知识蒸馏是无效的
  • student无法模拟teacher

假设一
高温可以防止teacher模型的soft label与hard label近似,防止soft label给出的信息近似于hard label,但是当温度为20时,在ImageNet数据集上依然会出现“怪相”,如下图:
深度学习论文笔记(rethinking knowledge distillation)——On the Efficacy of Knowledge Distillation_第2张图片
因此作者否定了这个假设


假设二、假设三
深度学习论文笔记(rethinking knowledge distillation)——On the Efficacy of Knowledge Distillation_第3张图片
当student模型为ResNet18,teacher模型分别为ResNet18、34、50时,在ImageNet上运用知识蒸馏的结果如上,KD(Train)表示训练集上的KD loss值,CE(Train)表示训练集上交叉熵loss的值,有(ES KD)符号标记的数据可以暂时不看。

如果student模型可以模拟teacher模型,那么在训练集上的KD loss应该趋近于0,依据上图(无ES KD部分),我们可以得知student模型无法模拟teacher模型(KD loss大于1),并且模型越大,KD loss也越大,这说明student模型越来越难以模拟teacher模型。由此推翻了假设二,印证了假设三

可以看到ResNet18(无ES KD)一行的KD loss值非常大,这并不能说明问题,由于student模型和teacher模型共享参数空间,teacher模型找到的解存在于student模型的参数空间中,这个解可以使KD loss取值接近于0,这里KD loss这么大,很可能归因于优化算法不够智能,或是初始化参数不一致,导致无法找到teacher模型的解。


可能的解决方案

上一节我们通过实验证明了假设三,具体而言,即teacher模型越大,student模型越难在参数空间中找到一个不错的近似解(体现在KD loss会随着teacher模型容量增大而增大),导致student模型性能越来越糟糕。

依据上述假设,作者给出了三个可能的解决方案

  • 初期使用交叉熵+KD loss作为损失函数,训练一段时间后,只使用交叉熵损失函数
  • 使用Sequential knowledge distillation,即选择一个容量位于student模型与teacher模型之间的middle模型,先将teacher的知识蒸馏到middle模型,在将middle模型的知识蒸馏到student模型
  • 对teacher模型使用early stop,即teacher模型训练一定epoch后停止训练,接着进行蒸馏

解决方案一
由于student模型难以找到近似解,那就是用知识蒸馏做一个pretrain,接着用交叉熵损失函数,以求找到一个尽可能好的解,实验结果为下图(含有ES KD符号)
深度学习论文笔记(rethinking knowledge distillation)——On the Efficacy of Knowledge Distillation_第4张图片
可以看到,ES KD的性能优于知识蒸馏,但是仍然会出现“怪相”。


解决方案二
深度学习论文笔记(rethinking knowledge distillation)——On the Efficacy of Knowledge Distillation_第5张图片
看上图第一、三行最后一列数据,基本没有差别,这里其实有一个核心的问题,要选择怎样的middle模型,才能即让middle模型找到teacher模型解近似的解,又让student模型(small模型)找到与middle模型解近似的解,这似乎把问题变得更加复杂,解决方案二是不能work的。


解决方案三

通过early stop teacher模型的训练,让teacher模型找到的解尽可能简单,从而方便student模型找到对应的近似解。

在CIFAR10数据集上使用上述策略,结果如下:
深度学习论文笔记(rethinking knowledge distillation)——On the Efficacy of Knowledge Distillation_第6张图片

x轴表示teacher模型训练的epoch数目,可以看到,在一定范围内,当teacher模型越大,student模型的错误率越小,比如第一幅图中的WRN-4、6、8,当epoch太小时,此时teacher模型找到的解可能很糟糕,这导致student模型性能较差,当epoch太大时,teacher模型找到的解太复杂,student模型难以找到近似解,导致student模型性能较差。


如果您想了解更多有关深度学习、机器学习基础知识,或是java开发、大数据相关的知识,欢迎关注我们的公众号,我将在公众号上不定期更新深度学习、机器学习相关的基础知识,分享深度学习中有趣文章的阅读笔记。

在这里插入图片描述

你可能感兴趣的:(深度学习)