转自:【经典简读】知识蒸馏(Knowledge Distillation) 经典之作
作者:潘小小
知识蒸馏是一种模型压缩方法,是一种基于“教师-学生网络思想”的训练方法,由于其简单,有效,在工业界被广泛应用。这一技术的理论来自于2015年Hinton发表的一篇神作:Distilling the Knowledge in a Neural Network。Knowledge Distillation,简称KD,顾名思义,就是将已经训练好的模型包含的知识(”Knowledge”),蒸馏(“Distill”)提取到另一个模型里面去。今天,我们就来简单读一下这篇论文,力求用简单的语言描述论文作者的主要思想。
在本文中,我们将从背景和动机讲起,然后着重介绍“知识蒸馏”的方法,最后我会讨论“温度“这个名词:
介绍
知识蒸馏的理论依据
知识蒸馏的具体方法
关于"温度"的讨论
参考
虽然在一般情况下,我们不会去区分训练和部署使用的模型,但是训练和部署之间存在着一定的不一致性:
在训练过程中,我们需要使用复杂的模型,大量的计算资源,以便从非常大、高度冗余的数据集中提取出信息。在实验中,效果最好的模型往往规模很大,甚至由多个模型集成得到。而大模型不方便部署到服务中去,常见的瓶颈如下:
在部署时,我们对延迟以及计算资源都有着严格的限制。
因此,模型压缩(在保证性能的前提下减少模型的参数量)成为了一个重要的问题。而”模型蒸馏“属于模型压缩的一种方法。
插句题外话,我们可以从模型参数量和训练数据量之间的相对关系来理解underfitting和overfitting。AI领域的从业者可能对此已经习以为常,但是为了力求让小白也能读懂本文,还是引用我同事的解释(我印象很深)形象地说明一下:
模型就像一个容器,训练数据中蕴含的知识就像是要装进容器里的水。当数据知识量(水量)超过模型所能建模的范围时(容器的容积),加再多的数据也不能提升效果(水再多也装不进容器),因为模型的表达空间有限(容器容积有限),就会造成underfitting;而当模型的参数量大于已有知识所需要的表达空间时(容积大于水量,水装不满容器),就会造成overfitting,即模型的variance会增大(想象一下摇晃半满的容器,里面水的形状是不稳定的)。
上面容器和水的比喻非常经典和贴切,但是会引起一个误解: 人们在直觉上会觉得,要保留相近的知识量,必须保留相近规模的模型。也就是说,一个模型的参数量基本决定了其所能捕获到的数据内蕴含的“知识”的量。
这样的想法是基本正确的,但是需要注意的是:
知识蒸馏使用的是Teacher—Student模型,其中teacher是“知识”的输出者,student是“知识”的接受者。知识蒸馏的过程分为2个阶段:
在本论文中,作者将问题限定在分类问题下,或者其他本质上属于分类问题的问题,该类问题的共同点是模型最后会有一个softmax层,其输出值对应了相应类别的概率值。
如果回归机器学习最最基础的理论,我们可以很清楚地意识到一点(而这一点往往在我们深入研究机器学习之后被忽略): 机器学习最根本的目的在于训练出在某个问题上泛化能力强的模型。
而现实中,由于我们不可能收集到某问题的所有数据来作为训练数据,并且新数据总是在源源不断的产生,因此我们只能退而求其次,训练目标变成在已有的训练数据集上建模输入和输出之间的关系。由于训练数据集是对真实数据分布情况的采样,训练数据集上的最优解往往会多少偏离真正的最优解(这里的讨论不考虑模型容量)。
而在知识蒸馏时,由于我们已经有了一个泛化能力较强的Net-T,我们在利用Net-T来蒸馏训练Net-S时,可以直接让Net-S去学习Net-T的泛化能力。
一个很直白且高效的迁移泛化能力的方法就是:使用softmax层输出的类别的概率来作为“soft target”。
【KD的训练过程和传统的训练过程的对比】
KD的训练过程为什么更有效?
softmax层的输出,除了正例之外,负标签也带有大量的信息,比如某些负标签对应的概率远远大于其他负标签。而在传统的训练过程(hard target)中,所有负标签都被统一对待。也就是说,KD的训练方式使得每个样本给Net-S带来的信息量大于传统的训练方式。
【举个例子】
在手写体数字识别任务MNIST中,输出类别有10个。假设某个输入的“2”更加形似"3",softmax的输出值中"3"对应的概率为0.1,而其他负标签对应的值都很小,而另一个"2"更加形似"7","7"对应的概率为0.1。这两个"2"对应的hard target的值是相同的,但是它们的soft target却是不同的,由此我们可见soft target蕴含着比hard target多的信息。并且soft target分布的熵相对高时,其soft target蕴含的知识就更丰富。
这就解释了为什么通过蒸馏的方法训练出的Net-S相比使用完全相同的模型结构和训练数据只使用hard target的训练方法得到的模型,拥有更好的泛化能力。
先回顾一下原始的softmax函数:
q i = exp ( z i ) ∑ j exp ( z j ) q_i=\frac{\exp(z_i)}{\sum_j\exp(z_j)} qi=∑jexp(zj)exp(zi)
但要是直接使用softmax层的输出值作为soft target, 这又会带来一个问题: 当softmax输出的概率分布熵相对较小时,负标签的值都很接近0,对损失函数的贡献非常小,小到可以忽略不计。因此**“温度”**这个变量就派上了用场。
下面的公式时加了温度这个变量之后的softmax函数:
q i = exp ( z i / T ) ∑ j exp ( z j / T ) q_i=\frac{\exp(z_i/T)}{\sum_j\exp(z_j/T)} qi=∑jexp(zj/T)exp(zi/T)
训练Net-T的过程很简单,下面详细讲讲第二步:高温蒸馏的过程。高温蒸馏过程的目标函数由distill loss(对应soft target)和student loss(对应hard target)加权得到。示意图如上。
L = α L s o f t + β L h a r d L=\alpha L_{soft}+\beta L_{hard} L=αLsoft+βLhard
v i v_i vi: Net-T的logits
z i z_i zi: Net-S的logits
p i T p_i^T piT: Net-T的在温度=T下的softmax输出在第i类上的值
q i T q_i^T qiT: Net-S的在温度=T下的softmax输出在第i类上的值
c i c_i ci: 在第i类上的ground truth值, ci∈{0,1}, 正标签取1,负标签取0.
N N N: 总标签数量
Net-T 和 Net-S 同时输入 transfer set (这里可以直接复用训练Net-T用到的training set), 用Net-T产生的softmax distribution (with high temperature) 来作为soft target,Net-S在相同温度T条件下的softmax输出和soft target的cross entropy就是Loss函数的第一部分 L s o f t L_{soft} Lsoft:
L s o f t = − ∑ j N p j T log ( q j T ) L_{soft}=−\sum_j^Np_j^T\log(q_j^T) Lsoft=−j∑NpjTlog(qjT)
其中, p i T = exp ( v i / T ) ∑ k N exp ( v k / T ) , q i T = exp ( z i / T ) ∑ k N exp ( z k / T ) p_i^T=\frac{\exp(v_i/T)}{\sum_k^N\exp(v_k/T)},\ q_i^T=\frac{\exp(z_i/T)}{\sum_k^N\exp(z_k/T)} piT=∑kNexp(vk/T)exp(vi/T), qiT=∑kNexp(zk/T)exp(zi/T) 。
Net-S在T=1的条件下的softmax输出和ground truth的cross entropy就是Loss函数的第二部分 L h a r d L_{hard} Lhard。
L h a r d = − ∑ j N c j log ( q j T = 1 ) L_{hard}=-\sum_j^Nc_j\log(q_j^{T=1}) Lhard=−j∑Ncjlog(qjT=1)
其中, q j T = 1 = exp ( z i ) ∑ k N exp ( z k ) q^{T=1}_j=\frac{\exp(z_i)}{\sum_k^N\exp(z_k)} qjT=1=∑kNexp(zk)exp(zi)
第二部分Loss L h a r d L_{hard} Lhard 的必要性其实很好理解: Net-T也有一定的错误率,使用ground truth可以有效降低错误被传播给Net-S的可能。打个比方,老师虽然学识远远超过学生,但是他仍然有出错的可能,而这时候如果学生在老师的教授之外,可以同时参考到标准答案,就可以有效地降低被老师偶尔的错误“带偏”的可能性。
【讨论】
实验发现第二部分所占比重比较小的时候,能产生最好的结果,这是一个经验的结论。一个可能的原因是,由于soft target产生的gradient与hard target产生的gradient之间有与 T 相关的比值。原论文中只是一笔带过,我在下面补充了一些简单的推导。(ps. 下面推导可能有些错误,如果有读者能够正确推出来请私信我~)
Soft Target
L s o f t = − ∑ j N p j T log ( q j T ) = − ∑ j N z j / T × exp ( v j / T ) ∑ k N exp ( v k / T ) ( 1 ∑ k N exp ( z k / T ) − exp ( z j / T ) ( ∑ k N exp ( z k / T ) ) 2 ) ≈ − 1 T ∑ k N exp ( v k / T ) ( ∑ j N z j exp ( v j / T ) ∑ k N exp ( z k / T ) − ∑ j N z j exp ( z j / T ) exp ( v j / T ) ( ∑ k N exp ( z k / T ) ) 2 ) \begin{aligned} L_{soft}&=-\sum_j^Np_j^T\log(q_j^T)\\ &=-\sum_j^N\frac{z_j/T\times\exp(v_j/T)}{\sum_k^N\exp(v_k/T)}(\frac{1}{\sum_k^N\exp(z_k/T)}-\frac{\exp(z_j/T)}{(\sum_k^N\exp(z_k/T))^2})\\ &\approx-\frac{1}{T\sum_k^N\exp(v_k/T)}(\frac{\sum_j^Nz_j\exp(v_j/T)}{\sum_k^N\exp(z_k/T)}-\frac{\sum_j^Nz_j\exp(z_j/T)\exp(v_j/T)}{(\sum_k^N\exp(z_k/T))^2}) \end{aligned} Lsoft=−j∑NpjTlog(qjT)=−j∑N∑kNexp(vk/T)zj/T×exp(vj/T)(∑kNexp(zk/T)1−(∑kNexp(zk/T))2exp(zj/T))≈−T∑kNexp(vk/T)1(∑kNexp(zk/T)∑jNzjexp(vj/T)−(∑kNexp(zk/T))2∑jNzjexp(zj/T)exp(vj/T))
Hard Target
L h a r d = − ∑ j N c j log ( q j T = 1 ) = − ( ∑ j N c j z j ∑ k N exp ( z k ) − ∑ j N c j z j exp ( z j ) ( ∑ k N exp ( z k ) ) 2 ) L_{hard}=-\sum_j^Nc_j\log(q^{T=1}_j)=-(\frac{\sum_j^Nc_jz_j}{\sum_{k}^N\exp(z_k)}-\frac{\sum_j^Nc_jz_j\exp(z_j)}{(\sum_k^N\exp(z_k))^2}) Lhard=−j∑Ncjlog(qjT=1)=−(∑kNexp(zk)∑jNcjzj−(∑kNexp(zk))2∑jNcjzjexp(zj))
由于 ∂ L s o f t ∂ z i \frac{\partial{L_{soft}}}{\partial{z_i}} ∂zi∂Lsoft 的magnitude大约是 ∂ L h a r d ∂ z i \frac{\partial{L_{hard}}}{\partial{z_i}} ∂zi∂Lhard 的 1 T 2 \frac{1}{T^2} T21 ,因此在同时使用soft target和hard target的时候,需要在soft target之前乘上 T 2 T^2 T2 的系数,这样才能保证soft target和hard target贡献的梯度量基本一致。
【注意】 在Net-S训练完毕后,做inference时其softmax的温度T要恢复到1.
直接match logits指的是,直接使用softmax层的输入logits(而不是输出)作为soft targets,需要最小化的目标函数是Net-T和Net-S的logits之间的平方差。
直接上结论: 直接match logits的做法是 T → ∞ T\rightarrow\infty T→∞ 的情况下的特殊情形。
由单个case贡献的loss,推算出对应在Net-S每个logit z i z_i zi上的gradient:
∂ L s o f t ∂ z i = 1 T ( q i − p i ) = 1 T ( exp ( z i / T ) ∑ j exp ( z j / T ) − exp ( v i / T ) ∑ j exp ( v j / T ) ) \frac{\partial{L_{soft}}}{\partial z_i}=\frac{1}{T}(q_i-p_i)=\frac{1}{T}(\frac{\exp(z_i/T)}{\sum_j\exp(z_j/T)}-\frac{\exp(v_i/T)}{\sum_j\exp(v_j/T)}) ∂zi∂Lsoft=T1(qi−pi)=T1(∑jexp(zj/T)exp(zi/T)−∑jexp(vj/T)exp(vi/T))
当 T → ∞ T\rightarrow \infty T→∞ 时,我们使用 1 + x / T 1+x/T 1+x/T 来近似 exp ( x / T ) \exp(x/T) exp(x/T) ,于是得到
∂ L s o f t ∂ z i ≈ 1 T ( 1 + z i / T N + ∑ j z j / T − 1 + v i / T N + ∑ j v j / T ) \frac{\partial L_{soft}}{\partial z_i}\approx\frac{1}{T}(\frac{1+z_i/T}{N+\sum_jz_j/T}-\frac{1+v_i/T}{N+\sum_jv_j/T}) ∂zi∂Lsoft≈T1(N+∑jzj/T1+zi/T−N+∑jvj/T1+vi/T)
如果再加上 logits 是零均值的假设 ∑ j z j = ∑ j v j = 0 \sum_jz_j=\sum_jv_j=0 ∑jzj=∑jvj=0 。那么上面的公式可以简化成:
∂ L s o f t ∂ z i ≈ 1 N T 2 ( z i − v i ) \frac{\partial L_{soft}}{\partial z_i}\approx\frac{1}{NT^2}(z_i-v_i) ∂zi∂Lsoft≈NT21(zi−vi)
等价于 minimize 以下损失函数
L s o f t ′ = 1 / 2 ( z i − v i ) 2 L'_{soft}={1}/{2}(z_i-v_i)^2 Lsoft′=1/2(zi−vi)2
【问题】 我们都知道“蒸馏”需要在高温下进行,那么这个“蒸馏”的温度代表了什么,又是如何选取合适的温度?如下图所示,随着温度T的增大,概率分布的熵逐渐增大。
在回答这个问题之前,先讨论一下温度T的特点
温度的高低改变的是Net-S训练过程中对负标签的关注程度: 温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Net-S会相对多地关注到负标签。
实际上,负标签中包含一定的信息,尤其是那些值显著高于平均值的负标签。但由于Net-T的训练过程决定了负标签部分比较noisy,并且负标签的值越低,其信息就越不可靠。因此温度的选取比较empirical,本质上就是在下面两件事之中取舍:
总的来说,T的选择和Net-S的大小有关,Net-S参数量比较小的时候,相对比较低的温度就可以了(因为参数量小的模型不能capture all knowledge,所以可以适当忽略掉一些负标签的信息)