这几天为了做毕设将网络上关于知识蒸馏的文章上进行了归纳
近年来,神经模型在几乎所有领域都取得了成功,包括极端复杂的问题。然而,这些模型体积巨大,有数百万(甚至数十亿)个参数,这些模型中的大多数在移动电话或嵌入式设备上运行的计算成本太过昂贵,因此不能部署在边缘设备上。
知识蒸馏指的是模型压缩的思想,通过一步一步地使用一个较大的已经训练好的网络去教导一个较小的网络确切地去做什么。“软标签”指的是大网络在每一层卷积后输出的feature map。然后,通过尝试复制大网络在每一层的输出(不仅仅是最终的损失),小网络被训练以学习大网络的准确行为。
显然,模型越复杂,理论搜索空间越大。但是,如果我们假设较小的网络也能实现相同(甚至相似)的收敛,那么教师网络的收敛空间应该与学生网络的解空间重叠。
不幸的是,仅凭这一点并不能保证学生网络收敛在同一点。学生网络的收敛点可能与教师网络有很大的不同。但是,如果引导学生网络复制教师网络的行为(教师网络已经在更大的解空间中进行了搜索),则其预期收敛空间会与原有的教师网络收敛空间重叠。
1.训练教师网络:首先使用完整数据集分别对高度复杂的教师网络进行训练。这个步骤需要高计算性能,因此只能在高性能gpu上完成。
2.构建对应关系:在设计学生网络时,需要建立学生网络的中间输出与教师网络的对应关系。这种对应关系可以直接将教师网络中某一层的输出信息传递给学生网络,或者在传递给学生网络之前进行一些数据增强。
3.通过教师网络前向传播:教师网络前向传播数据以获得所有中间输出,然后对其应用数据增强(如果有的话)。
4.通过学生网络反向传播:现在利用教师网络的输出和学生网络中反向传播误差的对应关系,使学生网络能够学会复制教师网络的行为。
“很多昆虫在幼虫形态的时候是最擅长从环境中吸取能量和养分的,而当他们成长为成虫的时候则需要擅长完全不同能力比如迁移和繁殖。”在2014年Hinton发表的知识蒸馏的论文中用了这样一个很形象的比喻来说明知识蒸馏的目的。在大型的机器学习任务中,我们也用两个不同的阶段 training stage 和 deployment stage 来表达两种不同的需求。training stage(训练阶段)可以利用大量的计算资源不需要实时响应,利用大量的数据进行训练。但是在deployment stage (部署阶段)则会有很多限制,比如计算资源,计算速度要求等。知识蒸馏就是为了满足这种需求而设计的一种模型压缩的方法。
知识蒸馏的概念最早是在2006年由Bulica提出的,在2014年Hinton对知识蒸馏做了归纳和发展。知识蒸馏的主要思想是训练一个小的网络模型来模仿一个预先训练好的大型网络或者集成的网络。这种训练模式又被称为 "teacher-student",大型的网络是“老师”,小型的网络是“学生”。
在知识蒸馏中,老师将知识传授给学生的方法是:在训练学生的过程中最小化一个以老师预测结果的概率分布为目标的损失函数。老师预测的概率分布就是老师模型的最后的softmax函数层的输出,然而,在很多情况下传统的softmax层的输出,正确的分类的概率值非常大,而其他分类的概率值几乎接近于0。因此,这样并不会比原始的数据集提供更多有用的信息,没有利用到老师强大的泛化性能,比如,训练MNIST任务中数字‘3’相对于数字‘5’与数字‘8’的关系更加紧密。为了解决这个问题,Hinton在2015年发表的论文中提出了‘softmax temperature’的概念,对softmax函数做了改进:
这里的T就是指 temperature 参数。当T等于1 时就是标准的softmax函数。当T增大时,softmax输出的概率分布就会变得更加 soft(平滑),这样就可以利用到老师模型的更多信息(老师觉得哪些类别更接近于要预测的类别)。Hinton将这样的蕴含在老师模型中的信息称之为 "dark knowledge",蒸馏的方法就是要将这些 "dark knowledge" 传给学生模型。在训练学生的时候,学生的softmax函数使用与老师的相同的T,损失函数以老师输出的软标签为目标。这样的损失函数我们称为"distillation loss"。
在Hinton的论文中,还发现了在训练过程加上正确的数据标签(hard label)会使效果更好。具体方法是,在计算distillation loss的同时,我利用hard label 把标准的损失(T=1)也计算出来,这个损失我们称之为 "student loss"。将两种 loss 整合的模型的具体结构如下图所示:
Hinton的论文中使用的T的范围为1到20,他们通过实验发现,当学生模型相对于老师模型非常小的时候,T的值相对小一点效果更好。这样的结果直观的理解就是,如果增加T的值,软标签的分布蕴含的信息越多,导致一个小的模型无法"捕捉"所有信息,但是这也只是一种假设,还没有明确的方法来衡量一个网络“捕捉”信息的能力。
另外需要注意一点的是:当蒸馏网的两个隐藏层中每一层都有300个或更多的单位时,所有T≥8的结果都相当相似,但当降低到每层30个单位时,T设置在2.5到4之间的效果要明显比T≥4要好。该现象可能说明将概率设置的过于soften可能会导致一些问题,尤其是在拟合能力较差的网络中。
在hinton的PPT中可以看到硬目标和软目标的例子(链接),关于为何要使用soft targets可查看https://www.zhihu.com/question/50519680/answer/136406661。
具体蒸馏结构如下图所示:
这里λ是hard target与soft target的权重
假设这里选取的 T = 10;
Teacher 模型:
( a ) 求Softmax(T=10)的输出,生成 “Soft targets”
Student 模型:
( a ) 对 Softmax(T = 10)的输出与Teacher 模型的Softmax(T = 10)的输出求 Loss1
( b ) 对 Softmax(T = 1)的输出与原始label 求 Loss2
( c ) Loss = Loss1 + (1/T^2)Loss2
Hinton的论文中做了三个实验,前两个是MNIST和语音识别,在这两个实验中通过知识蒸馏得到的学生模型都达到了与老师模型相近的效果,相对于直接在原始数据集上训练的相同的模型在准确率上都有很大的提高。下面主要讲述第三个比较创新的实验:将知识蒸馏应用在训练集成模型中。
训练集成模型(训练多个同样的模型然后集成得到更好的泛化效果)是利用并行计算的非常简单的方法,但是当数据集很大种类很多的时候就会产生巨大的计算量而且效果也不好。Hinton在论文中利用soft label的技巧设计了一种集成模型降低了计算量又取得了很好的效果。这个模型包含两种小模型:generalist model 和 specialist model(网络模型相同,分工不同)整个模型由很多个specialist model 和一个generalist model 集成。顾名思义generalist model 是负责将数据进行粗略的区分(将相似的图片归为一类),而specialist model(专家模型)则负责将相似的图片进行更细致的分类。这样的操作也非常符合人类的大脑的思维方式先进行大类的区分再进行具体分类,下面我们看这个实验的具体细节。 实验所用的数据集是谷歌内部的JFT数据集,JFT数据集非常大,有一亿张图片和15000个类别。实验中 generalist model 是用所有数据集进行训练的,有15000个输出,也就是每个类别都有一个输出概率。将数据集进行分类则是用Online k-means聚类的方法对每张图片输入generalist model后得到的软标签进行聚类,最终将3%的数据为一组分发给各个specialist,每个小数据集包含一些聚集的图片,也就是generalist认为相近的图片。 在specialist model的训练阶段,模型的参数在初始化的时候是完全复制的generalist中的数值(specialist和generalist的结构是一模一样的),这样可以保留generalist模型的所有知识,然后specialist对分配的数据集进行hard label训练。但是问题是,specialist如果只专注于分配的数据集(只对分配的数据集训练)整个网络很快就会过拟合于分配的数据集上,所以Hinton提出的方法是用一半的时间进行hard label训练,另一半的时间用知识蒸馏的方法学习generalist生成的soft label。这样specialist就是花一半的时间在进行小分类的学习,另一半的时间是在模仿generalist的行为。 整个模型的预测也与往常不同。在做top-1分类的时候分为以下两步: 第一步:将图片输入generalist model 得到输出的概率分布,取概率最大的类别k。 第二步:取出数据集包含类别k的所有specialists,为集合(各个数据集之间是有类别重合的)。然后求解能使如下公式最小化的概率分布q作为预测分布。
这里的KL是指KL散度(用于刻画两个概率分布之间的差距)和
分别是测试图片输入generalist 和specialists(m)之后输出的概率分布,累加就是考虑所有属于
集合的specialist的“意见”。
由于Specialist model的训练数据集很小,所以需要训练的时间很短,从传统方法需要的几周时间减少到几天。下图是在训练好generalist模型之后逐个增加specialist进行训练的测试结果:
从图中可以看出,specialist个数的增加使top1准确个数有明显的提高。
Tann et al., 2017, Mishra and Marr, 2018和 Polino et al., 2018 将知识蒸馏与量化相结合。 Theis et al., 2018 和 Ashok et al., 2018 将蒸馏与修剪相结合。(论文地址见参考文献)