这篇文章非常有意思,本文文字部分较多,主要记录了个人对于文章的一些思考
《On the Efficacy of Knowledge Distillation》于2019年发表在ICCV上。通过实验,作者发现了一个“怪相”,准确率越高的模型并不一定就是好的teacher模型,对于同一个student模型而言,teacher模型越大,teacher模型的准确率越高,知识蒸馏得到的student模型性能却越差。作者认为是student模型与teacher模型的容量相差太大,导致student模型无法模拟teacher模型,从而出现上述怪相。为了解决上述“怪相”,作者提出在训练teacher模型时进行early stop,即teacher模型训练一定轮数后停止训练,此时teacher模型可能还未较好拟合训练集。
按我的理解,知识蒸馏的实质是想让student模型模拟teacher模型的输出,从而在student模型的参数空间中找到一个解,这个解是teacher模型找到的解的近似解,当student模型与teacher模型容量差距太大时,将导致student模型无法找到近似解,但是由此又引发一个问题,该如何防止student模型与teacher模型的容量差距太大呢?论文并没有解释。作者从另外一个角度提出了解决方案,通过在训练时early stop teacher模型,让teacher模型找到的解尽可能简单,从而让student模型尽可能找到近似解。
2015年,hinton在论文《Distilling the Knowledge in a Neural Network》中提出了知识蒸馏,本文中的“知识蒸馏”均指该方法。
上图展示了当student模型为resnet18,teacher模型为resnet18、34、50时,在IMageNet数据上,利用知识蒸馏得到的student模型的准确率,第一行表示不使用知识蒸馏训练的resnet18的准确率,从上图可以看出两个问题:
第二个问题为本人提出,问题二驳斥了目前网上对于知识蒸馏为什么有效的一个解释,即soft label相比于hard label,可以提供更多类与类之间相似性的信息,这类信息将有助于student模型区分类。但是从上图数据来看,当teacher模型为resnet18时,给出的soft label也可以反映类与类之间的相似性,但是student模型的准确率却并没有更高,因此个人不是很认同这个观点。
针对于问题二,个人的理解是——若student模型与teacher模型的结构不同,teacher模型性能优于student模型,此时知识蒸馏可能可以让student模型在参数空间中找到一个与teacher模型解近似的解,这个近似解通常不如teacher模型的解,但可能可以让student模型与teacher模型性能近似,从而提高student模型的性能。若student模型与teacher模型的结构相同,teacher模型的性能与student模型性能基本一致,此时student找到的近似解并不一定就能提高student模型的准确率。
针对问题一,作者提出了三种假设
假设一
高温可以防止teacher模型的soft label与hard label近似,防止soft label给出的信息近似于hard label,但是当温度为20时,在ImageNet数据集上依然会出现“怪相”,如下图:
因此作者否定了这个假设
假设二、假设三
当student模型为ResNet18,teacher模型分别为ResNet18、34、50时,在ImageNet上运用知识蒸馏的结果如上,KD(Train)表示训练集上的KD loss值,CE(Train)表示训练集上交叉熵loss的值,有(ES KD)符号标记的数据可以暂时不看。
如果student模型可以模拟teacher模型,那么在训练集上的KD loss应该趋近于0,依据上图(无ES KD部分),我们可以得知student模型无法模拟teacher模型(KD loss大于1),并且模型越大,KD loss也越大,这说明student模型越来越难以模拟teacher模型。由此推翻了假设二,印证了假设三。
可以看到ResNet18(无ES KD)一行的KD loss值非常大,这并不能说明问题,由于student模型和teacher模型共享参数空间,teacher模型找到的解存在于student模型的参数空间中,这个解可以使KD loss取值接近于0,这里KD loss这么大,很可能归因于优化算法不够智能,或是初始化参数不一致,导致无法找到teacher模型的解。
上一节我们通过实验证明了假设三,具体而言,即teacher模型越大,student模型越难在参数空间中找到一个不错的近似解(体现在KD loss会随着teacher模型容量增大而增大),导致student模型性能越来越糟糕。
依据上述假设,作者给出了三个可能的解决方案
解决方案一
由于student模型难以找到近似解,那就是用知识蒸馏做一个pretrain,接着用交叉熵损失函数,以求找到一个尽可能好的解,实验结果为下图(含有ES KD符号)
可以看到,ES KD的性能优于知识蒸馏,但是仍然会出现“怪相”。
解决方案二
看上图第一、三行最后一列数据,基本没有差别,这里其实有一个核心的问题,要选择怎样的middle模型,才能即让middle模型找到teacher模型解近似的解,又让student模型(small模型)找到与middle模型解近似的解,这似乎把问题变得更加复杂,解决方案二是不能work的。
解决方案三
通过early stop teacher模型的训练,让teacher模型找到的解尽可能简单,从而方便student模型找到对应的近似解。
x轴表示teacher模型训练的epoch数目,可以看到,在一定范围内,当teacher模型越大,student模型的错误率越小,比如第一幅图中的WRN-4、6、8,当epoch太小时,此时teacher模型找到的解可能很糟糕,这导致student模型性能较差,当epoch太大时,teacher模型找到的解太复杂,student模型难以找到近似解,导致student模型性能较差。
如果您想了解更多有关深度学习、机器学习基础知识,或是java开发、大数据相关的知识,欢迎关注我们的公众号,我将在公众号上不定期更新深度学习、机器学习相关的基础知识,分享深度学习中有趣文章的阅读笔记。