知识蒸馏Knowledge Distillation

知识蒸馏是模型压缩的一个重要方法,本文简要介绍了什么是知识蒸馏。

知识蒸馏Knowledge Distillation

1.什么是知识蒸馏

我浅谈一些我的看法,详细内容可以参考这篇文章
[https://zhuanlan.zhihu.com/p/90049906]

简单来说,就是我们一般训练模型时,可能为了有一个好的效果,就会加大网络深度,或者用一些复杂的网络,这样参数量就会好大…那么这么一个模型怎么弄到移动端呢?怎么能使得运行速度实时呢?

所以人们提出搞模型压缩!

知识蒸馏就是一类模型压缩方法,先训练大模型,再去引导做一个小模型。

它是怎么干的呢?

大模型会训练出一系列的softmax概率值,这样,原来我们需要让新模型的softmax分布与真实标签匹配,现在只需要让新模型与原模型在给定输入下的softmax分布匹配了。直观来看,后者比前者具有这样一个优势:经过训练后的原模型,其softmax分布包含有一定的知识——真实标签只能告诉我们,某个图像样本是一辆宝马,不是一辆垃圾车,也不是一颗萝卜;而经过训练的softmax可能会告诉我们,它最可能是一辆宝马,不大可能是一辆垃圾车,但绝不可能是一颗萝卜。

随后怎么做,简单来说就是可以小模型的概率z要逼近原模型的v,直接用下面损失也可以
在这里插入图片描述

2.知识蒸馏怎么做

  • 第一步:在训练集上训练好一个大模型A(通常叫做teacher model)
  • 第二步:在transfer set(可以和训练集是同一个数据集)上利用大模型A产生给每一个样本生成一个soft target(有利用一个temperature参数对logits进行平滑)
  • 第三步:在transfer set上对student model B进行训练,损失函数由两部分组成,都是交叉熵损失,只不过一个是拟合soft target,另外一个是拟合ground truth的hard target(如二分类中的0和1),其中在拟合hard target的损失函数和普通分类损失保持一致,在拟合soft target的损失函数时也利用了一个同样的temperature参数T
  • 第四步:保留student model进行线上预测,这个时候去掉soft target那一路,只保留普通分类的softmax
    知识蒸馏Knowledge Distillation_第1张图片

参考文献

[1]Distilling the Knowledge in a Neural Network论文笔记
https://zhuanlan.zhihu.com/p/74901192
[2]知识蒸馏是什么?一份入门随笔
https://zhuanlan.zhihu.com/p/90049906

你可能感兴趣的:(计算机视觉CV)