全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)

一.是什么?

把一个大的模型(定义为教师模型)萃取,蒸馏,把它浓缩到小的模型(定义为学生模型)。

即:大的神经网络把他的知识教给了小的神经网络。

全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第1张图片

二.为什么要用知识蒸馏把大模型学习到的东西迁移到小模型呢呢?

因为大的模型很臃肿,而真正落地的终端算力有限,比如手表,安防终端。
所以要把大模型变成小模型,把小模型部署到终端上。

全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第2张图片

2.1 轻量化网络的方向

分为下面四个方向,知识蒸馏是第一个方向。

全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第3张图片

三.用蒸馏温度处理学生网络的标签

学生网络有两种标签:

一种是教师网络的输出,
一种是真实的标签。

3.1 soft target

soft target使我们常用的概率版的标签值。比如输入下面的图片预测。

hard targets和soft targets的预测概率如下:
全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第4张图片
hard targets的预测结果不科学,因为马和驴比马和汽车相似的多。所以驴和汽车都是0,没有表现出这个信息,所以要用soft targets.

3.2 用教师网络预测出的soft target作为学生网络的标签。

教师网络预测出的soft target具有很多信息。

3.3 蒸馏温度

softmax有放大差异的功能。
如果值高那么一点点,经过softmax的放大就会变得很高。
如果想让soft target更加平缓,高的降低,低的升高。
这时就要对soft target使用蒸馏温度。 让soft target更soft。
实现方法是在softmax的分母处加个T。

全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第5张图片

效果如下:T=1时相当于没有蒸馏温度。T=3时确实低的更低高的更高了。

全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第6张图片
全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第7张图片

T和分布的关系如下图,T从1增加到10,值之间的差异越来越小,softmax的放大效果被冲淡。
当T=100的时候,结果直接变成一个横线,众生平等。

全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第8张图片

四.知识蒸馏训练过程

4.1 图示知识蒸馏训练过程

全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第9张图片
全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第10张图片

全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第11张图片

上面是已经训练好的教师网络。
把数据输入到教师网络,在输出时使用蒸馏温度为T的softmax.
再把数据输入到学生网络,学生网络可能是还没有训练的网络,也可能是训练一半的半成品网络。  

4.2 损失函数

学生网络既要在蒸馏温度等于T时与教师网络的结果相接近。
也要保证不使用蒸馏温度时的结果与真实结果相接近。

蒸馏损失:

把教师网络使用蒸馏温度为t的输出结果 与 学生网络蒸馏温度为t的输出结果做损失。
让这个损失越小越好。

学生损失:

学生网络蒸馏温度为1(即不使用蒸馏网络)时的预测结果和真实的标签做loss.

最后对这两项加权求和。

4.3 图解损失函数计算过程

红色线条指向的是学生损失。
紫色线条指向的是蒸馏损失。

全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第12张图片

五.推理过程

此时学生网络已经训练好,把X输入到学生网络得到结果。
全网最细图解知识蒸馏(涉及知识点:知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)_第13张图片

六.最终效果:

学生网络可以接近教师网络的识别效果,并且附加如下两个特点:

1.零样本识别

论文里面说:以手写体数字数据集为例,假如在训练学生网络时把标签为3的类别全部去掉,
但是教师网络学过3。当使用知识蒸馏将教师网络学到的东西迁移到学生网络时,学生网络虽然没有见过3,但是却能识别3,即达到了零样本的效果。

2.使用soft target训练而不是hard target,减少了过拟合

在这里插入图片描述

第二行和第三行是使用百分之3的训练样本并分别用hard target和soft target,结果显示

使用3%的训练样本 + hard target :
训练集的准确率为 67.3%, 测试集的准确率为44.5%
使用3%的训练样本 + soft target :
训练集的准确率为 65.4%, 测试集的准确率为57.5%

七.迁移学习和知识蒸馏的区别

迁移学习是把一个模型学习的领域泛化到另一个领域,比如把猫狗这些动物域迁移到医疗域。
知识蒸馏是把一个模型的知识迁移到另一个模型上。

八.参考视频

B站UP主,同济子豪兄的视频:
【精读AI论文】知识蒸馏
https://www.bilibili.com/video/BV1gS4y1k7vj/?spm_id_from=333.788&vd_source=ebc47f36e62b223817b8e0edff181613

你可能感兴趣的:(机器学习&深度学习笔记,深度学习,人工智能)