深度学习论文笔记(知识蒸馏)——Distilling the Knowledge in a Neural Network

文章目录

  • 主要工作
  • motivation
  • method
  • 实验

主要工作

  • 提出一种知识蒸馏的方法,可以压缩模型,让小模型达到与集成亦或是大型模型相似的性能
  • 提出一种新的集成学习方法,可以让模型训练速度更快,并且是并行训练

本文只总结第一点


motivation

大型模型往往不适合线上部署,一方面是计算资源消耗大,另一方面是响应速度慢,因此Hinton便考虑是否可以将大模型的知识迁移到小模型上,这里有两个问题


大型模型知识迁移到小型模型后,小型模型应该具有什么样的表现
最终目的是让小型模型的泛化能力和大型模型一致。


什么是模型的知识呢?

将知识看成是模型参数不合适,这太过复杂且难以入手,结合上一个疑问,个人认为,可以把模型知识抽象为泛化能力,如何让小型模型学得大型模型的泛化能力呢?在此之前,先了解一下什么是soft target。

对于分类模型而言,模型的输出是softmax,例如一个猫狗猪的三分类,对于一张狗的图片,模型输出为
[ 0.1 , 0.89 , 0.01 ] [0.1,0.89,0.01] [0.1,0.89,0.01]
那说明大型模型认为这张图片属于狗,但具有一些猫的分类特征,大型模型的输出含有两方面信息,一是这张图片属于什么类,二是与这张图片相似的类是什么,大型模型的输出又被称为是soft target(也有些变化,计算公式在下面),hard target即为one hot编码,相比于soft target,hard target的信息较少,只能表明一张图片是什么类别。

回到原先的问题,如果让小型模型去拟合大型模型的soft target,以期让小型模型学会像大型模型一样思考,那么小型模型的泛化能力不就可能和大型模型一致了吗?这便是论文的出发点。


method

更具体一点,论文将softmax的输出更改为:
在这里插入图片描述
大型模型与小型模型的输出均采用上式计算,采用上式计算的大模型输出即为soft target

T为超参数,T越大,可以产生更加soft的target,为什么要引入一个超参数呢?首先不论是否引入T,小型模型都是拟合大型模型的输出,简单举个例子,假设大型模型的输出为
[ 0.1 , 0.89 , . 0.01 ] [0.1,0.89,.0.01] [0.1,0.89,.0.01]
小型模型的输出为
[ P 1 ( x ) , P 2 ( x ) , P 3 ( x ) ] [P_1(x),P_2(x),P_3(x)] [P1(x),P2(x),P3(x)]
则交叉熵损失函数为
− [ 0.1 log ⁡ P 1 ( x ) + 0.89 log ⁡ P 2 ( x ) + 0.01 log ⁡ P 3 ( x ) ] -[0.1\log P_1(x)+0.89\log P_2(x)+0.01\log P_3(x)] [0.1logP1(x)+0.89logP2(x)+0.01logP3(x)]
第二项对于损失函数的取值影响很大,如果小型模型的输出为[0.3,0.67,0.03],那么交叉熵损失函数值大致为0.7,如果设置T为2,交叉熵损失函数值大致为1.62,比前者大很多,而小型模型的输出离大型模型还是有差距的,所以设置一个超参数T,可以让小型模型更好的拟合大型模型的输出。

换个角度来看,hard target只能表明图片是什么,而soft target还可表明与该图片相似的类别是什么,knowledge distillation相当于一种监督信息的弥补。

记采用soft target的交叉熵损失函数为 L s o f t L^{soft} Lsoft(小型模型的输出也要除以T后计算softmax),采用hard target的交叉熵损失函数为 L h a r d L^{hard} Lhard,单纯使用soft target的确可以使用不需要标记的数据训练小型模型,但是Hinton发现联合使用两个损失函数效果更佳,如下
L = α L s o f t + ( 1 − α ) L h a r d L=\alpha L^{soft}+(1-\alpha) L^{hard} L=αLsoft+(1α)Lhard

通常 α \alpha α的取值较大,具体流程如下图所示

深度学习论文笔记(知识蒸馏)——Distilling the Knowledge in a Neural Network_第1张图片

实验

有一个实验结果效果非常惊人,如下:
深度学习论文笔记(知识蒸馏)——Distilling the Knowledge in a Neural Network_第2张图片
仅使用了3%的训练数据,便能得到与使用100%训练数据的模型的效果,这说明soft target的监督能力足够强劲

你可能感兴趣的:(深度学习)