知识蒸馏介绍

知识蒸馏介绍

知识蒸馏介绍_第1张图片

A是效果比较好的大模型,但不适合部署在计算资源有限的小型设备上,可以用知识蒸馏的方法训练一个高效的小模型B。通常只应用于分类任务,且学生只能从头学起

知识蒸馏介绍_第2张图片

知识蒸馏介绍_第3张图片

知识蒸馏可以分为输出值知识蒸馏和特征值知识蒸馏

知识蒸馏介绍_第4张图片

 

小模型预测结果为[1,0,0],[0,7,0.29,0.01]是教师模型的,而另外两个图片概率为[0.29,0.01] ,也可能是西红柿和香蕉。这个信息对学生是有意义的,我们希望学生能从图片中提取到其他隐含知识。从老师模型中获得这个信息的过程叫做知识蒸馏。

实现知识蒸馏需要的函数

知识蒸馏介绍_第5张图片

 

知识蒸馏介绍_第6张图片

知识蒸馏介绍_第7张图片

log_softmax: 对输入值先算softmax,然后再对softmax的所有输出取一个ln,因为softmax的输出值都在[0,1]之间,所以log_softmax的输出结果都为负数,越接近0代表概率越大。

知识蒸馏介绍_第8张图片

 

nllloss: 输出值为log_softmax中对应真实标签的值的负数。比如log_softmax输出为[-1.2,-2,-3],真实标签为[0],那么nllloss为1.2。nllloss值越大代表误差越大。因为概率越小,log_softmax的输出值也会越小,对应的nllloss值就会越大。

log_softmax和nll_loss一起用的例子:

知识蒸馏介绍_第9张图片

 

 

p是一个one-hot形式的向量, q是表示概率的向量

知识蒸馏介绍_第10张图片

 

可以看到,直接对input用cross_entropy 和对input先用log_softmax,再用nll_loss 的输出结果是一样的

温度:

设t为温度,softmax函数会被改造,计算每个输出值时会由    变成  ,这会使输出的值更平均,不会过于极端,例如一个值靠近1,其他值都靠近0 

知识蒸馏介绍_第11张图片

可以看到,温度t=10时输出值更平均了,t=10000时,输出值基本都为0.33

蒸馏:

知识蒸馏介绍_第12张图片

 

t=1时,softmax层没有被影响,output向量中的元素[5]为正数,其他都为负数,所以经过softmax层后差异很大,认为[5]的概率接近1,其他元素概率都接近0

t=10时,softmax层被改造了,每个元素概率的差异缩小了。知识蒸馏就是想把下面那个概率图教给学生网络,让学生网络知道这张图片虽然很大概率是[5],但也有有小概率可能是[3]或[6],这种蒸馏出来的知识叫做‘暗知识’

老师-学生网络训练流程

知识蒸馏介绍_第13张图片

 

我们输入的是 一张图片x和它one-hot形式的真实值y,学生网络的输出经过softmax层得到p,然后计算p和y之间的交叉熵LossHARD ,我们希望y越接近p越好,这也是原始的方法。

而知识蒸馏会将教师网络输出经过温度t的蒸馏后得到s,再计算s和学生网络的交叉熵LossSOFT。最终的total loss是这两个Loss的带权相加。这样做的目的是我们一方面希望学生网络输出能接近真实值,另一方面也希望学到老师的知识。

参考:1. https://www.bilibili.com/video/BV1s7411h7K2

           2. https://www.bilibili.com/video/BV1SC4y1h7HB?p=7  

           3. https://zhuanlan.zhihu.com/p/387275923

你可能感兴趣的:(python,计算机视觉,目标检测,深度学习)