看论文之知识蒸馏

该部分为对论文"Looking for the Devil in the Details: Learning Trilinear Attention Sampling

Network for Fine-grained Image Recognition "中的第三部分知识蒸馏的了解。使用知识蒸馏来压缩模型。

1.什么是知识蒸馏?

(1)化学中:蒸馏是一种有效的 分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的。各组分沸点不同,蒸馏时要根据目标物质的沸点设置蒸馏温度。
(2)深度学习中:一般地,大模型往往是单个复杂网络或者是若干网络的集合,拥有良好的性能和泛化能力,而小模型因为网络规模较小,表达能力有限。利用大模型学习到的知识去 指导小模型训练,使得 小模型具有与大模型相当的性能,但是 参数数量大幅降低,从而 实现模型压缩与加速,这就是知识蒸馏与迁移学习在模型优化中的应用。
看论文之知识蒸馏_第1张图片

Hinton等人最早在文章《Distilling the Knowledge in a Neural Network》中提出了知识蒸馏这个概念,其核心思想是先训练一个复杂网络模型,然后使用这个复杂网络的输出数据的真实标签去训练一个更小的网络,因此知识蒸馏框架通常包含了一个复杂模型(被称为Teacher模型)和一个小模型(被称为Student模型)。

通过设置loss函数来进行迭代更新,使得小模型的学习达到一个与真实标签相近的结果。

2.知识蒸馏的作用

提升模型精度、降低模型时延、压缩网络参数、标签之间的域迁移等。

3.知识蒸馏

知识蒸馏是对模型的能力进行迁移,根据迁移的方法不同可以简单分为基于目标蒸馏(也称为Soft-target蒸馏或Logits方法蒸馏)和基于特征蒸馏的算法两个大的方向,下面我们对其进行介绍。

一个很直白且高效的迁移泛化能力的方法就是:使用softmax层输出的类别的概率来作为“Soft-target” 。

3.1 Hard-target 和Soft-target

传统的神经网络训练方法是定义一个损失函数,目标是使预测值尽可能接近于真实值(Hard- target),损失函数就是使神经网络的损失值和尽可能小。这种训练过程是对ground truth求极大似然。在知识蒸馏中,是使用大模型的类别概率作为Soft-target的训练过程。

看论文之知识蒸馏_第2张图片

  • Hard-target:原始数据集标注的 one-shot 标签,除了正标签为 1,其他负标签都是 0。

  • Soft-target:Teacher模型softmax层输出的类别概率,每个类别都分配了概率,正标签的概率最高。

3.2 知识蒸馏为什么有用?

  1. softmax层的输出,除了正例之外,负标签也带有Teacher模型归纳推理的大量信息,比如某些负标签对应的概率远远大于其他负标签,则代表 Teacher模型在推理时认为该样本与该负标签有一定的相似性。而在传统的训练过程(Hard-target)中,所有负标签都被统一对待。也就是说,知识蒸馏的训练方式使得每个样本给Student模型带来的信息量大于传统的训练方式。
  2. Soft-target分布的熵相对高时,其Soft-target蕴含的知识就更丰富。
  3. 使用 Soft-target 训练时,梯度的方差会更小,训练时可以使用更大的学习率,所需要的样本也更少。
这也解释了为什么通过蒸馏的方法训练出的Student模型相比使用完全相同的模型结构和训练数据只使用Hard-target的训练方法得到的模型,拥有更好的泛化能力。

3.3 温度T

直接使用softmax层的输出值作为soft target,这又会带来一个问题: 当softmax输出的概率分布熵相对较小时,负标签的值都很接近0,对损失函数的贡献非常小,小到可以忽略不计。因此"温度"这个变量就派上了用场。

T越高,softmax的output probability distribution越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签。

看论文之知识蒸馏_第3张图片

 温度的高低改变的是Student模型训练过程中对负标签的关注程度。当温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Student模型会相对更多地关注到负标签。

  • 当想从负标签中学到一些信息量的时候,温度T应调高一些;

  • 当想减少负标签的干扰的时候,温度T应调低一些;

总的来说, T的选择和Student模型的大小有关,Student模型参数量比较小的时候,相对比较低的温度就可以了。因为参数量小的模型不能学到所有Teacher模型的知识,所以可以适当忽略掉一些负标签的信息。

最后,在整个知识蒸馏过程中,我们先让温度T升高,然后在测试阶段恢复“低温“,从而将原模型中的知识提取出来,因此将其称为是蒸馏,实在是妙啊。

3.4.知识蒸馏训练的步骤

  1. 训练好Teacher模型;
  2. 利用高温产生T=t产生Soft-target;
  3. 使用T=t和T=1同时训练Student模型;
  4. 设置温度T,Student模型线上做inference

看论文之知识蒸馏_第4张图片

训练Teacher的过程很简单,我们把第2步和第3步过程统一称为:高温蒸馏的过程。高温蒸馏过程的目标函数由distill loss(对应Soft-target)和Student loss(对应Hard-target)加权得到。当权重较小时,能产生最好的效果,这是一个经验性的结论。

  • distill loss:用Teacher模型在高温t下产生的softmax distribution来作为Soft-target,Student模型在相同温度t条件下的softmax输出和Soft-target的cross entropy就是Loss函数的第一部分.
  •  Student模型在T=1的条件下的softmax输出和ground truth的cross entropy就是Loss函数的第二部分 。Teacher模型也有一定的错误率,使用ground truth可以有效降低错误被传播给Student模型的可能性。

PS:

H(p,q)=-\sum_xp(x)logq(x)

交叉熵刻画的是两个概率分布之间的距离,或可以说它刻画的是通过概率分布q来表达概率分布p的困难程度,p代表正确答案,q代表的是预测值,交叉熵越小,两个概率的分布约接近。

本文是阅读一下博客写的小笔记:

https://blog.csdn.net/Kaiyuan_sjtu/article/details/深度学习中的知识蒸馏技术(上)谢谢!

你可能感兴趣的:(深度学习,人工智能,神经网络)