知识蒸馏介绍
A是效果比较好的大模型,但不适合部署在计算资源有限的小型设备上,可以用知识蒸馏的方法训练一个高效的小模型B。通常只应用于分类任务,且学生只能从头学起
知识蒸馏可以分为输出值知识蒸馏和特征值知识蒸馏
小模型预测结果为[1,0,0],[0,7,0.29,0.01]是教师模型的,而另外两个图片概率为[0.29,0.01] ,也可能是西红柿和香蕉。这个信息对学生是有意义的,我们希望学生能从图片中提取到其他隐含知识。从老师模型中获得这个信息的过程叫做知识蒸馏。
实现知识蒸馏需要的函数
log_softmax: 对输入值先算softmax,然后再对softmax的所有输出取一个ln,因为softmax的输出值都在[0,1]之间,所以log_softmax的输出结果都为负数,越接近0代表概率越大。
nllloss: 输出值为log_softmax中对应真实标签的值的负数。比如log_softmax输出为[-1.2,-2,-3],真实标签为[0],那么nllloss为1.2。nllloss值越大代表误差越大。因为概率越小,log_softmax的输出值也会越小,对应的nllloss值就会越大。
log_softmax和nll_loss一起用的例子:
p是一个one-hot形式的向量, q是表示概率的向量
可以看到,直接对input用cross_entropy 和对input先用log_softmax,再用nll_loss 的输出结果是一样的
温度:
设t为温度,softmax函数会被改造,计算每个输出值时会由 变成
,这会使输出的值更平均,不会过于极端,例如一个值靠近1,其他值都靠近0
可以看到,温度t=10时输出值更平均了,t=10000时,输出值基本都为0.33
蒸馏:
t=1时,softmax层没有被影响,output向量中的元素[5]为正数,其他都为负数,所以经过softmax层后差异很大,认为[5]的概率接近1,其他元素概率都接近0
t=10时,softmax层被改造了,每个元素概率的差异缩小了。知识蒸馏就是想把下面那个概率图教给学生网络,让学生网络知道这张图片虽然很大概率是[5],但也有有小概率可能是[3]或[6],这种蒸馏出来的知识叫做‘暗知识’
我们输入的是 一张图片x和它one-hot形式的真实值y,学生网络的输出经过softmax层得到p,然后计算p和y之间的交叉熵LossHARD ,我们希望y越接近p越好,这也是原始的方法。
而知识蒸馏会将教师网络输出经过温度t的蒸馏后得到s,再计算s和学生网络的交叉熵LossSOFT。最终的total loss是这两个Loss的带权相加。这样做的目的是我们一方面希望学生网络输出能接近真实值,另一方面也希望学到老师的知识。
参考:1. https://www.bilibili.com/video/BV1s7411h7K2
2. https://www.bilibili.com/video/BV1SC4y1h7HB?p=7
3. https://zhuanlan.zhihu.com/p/387275923