Knowledge Distillation(知识蒸馏)

Do Deep Nets Really Need to be Deep?
虽然近年来的趋势如BigGAN,BERT等,动辄上亿参数,几乎就是数据驱动+算力的“暴力”结果。但同时,更加轻量级的升级版模型如ALBERT也能以更少的参数和架构持续刷榜,元学习(meta learning)和零样本学习(Zero-shot learning),还有只需要个位数层数就能取得优异效果的GCN等,都似乎证明了“大道至简”。

深度模型的压缩方法一般有:

  • 参数修剪和共享(parameter pruning and sharing)
  • 低秩因子分解(low-rank factorization)
  • 转移/紧凑卷积滤波器(transferred/compact convolutional filters)
  • 知识蒸馏(knowledge distillation)

Knowledge Distillation
正如“蒸馏”这一主观形象,知识蒸馏的目的就是为了“提纯”。将神经网络轻量化,网络压缩化,以更少的参数得到尽可能相似的网络效果。轻量化的结果可以降低能耗,完成多任务平行等,就可以使一些原先需要在大设备上运行的结构能在小设备上工作,如在手机上,或者嵌入式移动开发板中。那么如何把复杂模型或者多个模型Ensemble(Teacher)学到的知识,迁移到另一个轻量级模型( Student )上,使模型变轻量的同时(方便部署),尽量不损失性能呢?
Knowledge Distillation(知识蒸馏)_第1张图片
Teacher-Student
最早的Knowledge Distillation的思想始于深度学习三巨头之一的Hinton于2014年提出《Distilling the Knowledge in a Neural Network 》,目的是让更浅的层数,更少的参数较少的浅层网络,取得和深层网络相近的效果,以完成网络压缩。具体实现主要是利用学生网络(student network)和教师网络(teacher network),即用一个很大的教师网络给较小的学生进行指导。

如上图一种直观的迁移知识的方式是,将教师网络(big)生成的类别概率向量作为训练学生网络(small)的soft targets(即通过softmax输出而不是one-hot这种真实的hard targets)。这样做,可以让学生网络捕捉到由真实的标签(one-hot)提供的信息,但也包括由教师网络学习更为丰富的数据信息,另一方面soft targets相比于hard targets意义也更多。
T是教师网络,S是学生网络,有:
P T τ = s o f t m a x ( a T τ ) P_T^{\tau}=softmax(\frac{a_T}{^{\tau}}) PTτ=softmax(τaT) P S τ = s o f t m a x ( a S τ ) P_S^{\tau}=softmax(\frac{a_S}{^{\tau}}) PSτ=softmax(τaS)
τ > 1 \tau > 1 τ>1以减轻教师网络输出所产生的信号,以指导学生更多的信息,这就相当于在迁移学习的过程中添加了扰动,从而使得学生网络在借鉴学习的时候更有效、泛化能力更强,这其实就是一种抑制过拟合的策略(取值在1-20)。
学生网络优化的损失函数为:
L K D = H ( y t r u e , P s ) + λ H ( P T τ , P S τ ) L_{KD}=H(y_{true},P_s)+\lambda H(P_T^{\tau},P_S^{\tau}) LKD=H(ytrue,Ps)+λH(PTτ,PSτ)
这里H代表交叉熵, λ \lambda λ用来平衡两个交叉熵。第一项就是一个网络的输出和标签之间传统的交叉熵,而第二项则强制学生网络从教师网络的‘软’输出中学习,尽可能的接近教师。输出作为知识。

为什么软标签能提升信息量?
one-hot的[0.1,0,0,0,0.9]和soft的[0,0,0,0,1]显然信息量是不同的。所以它的优势

  • 弥补了简单分类中监督信号不足(信息熵比较少)的问题,增加了信息量
  • 提供了训练数据中类别之间的关系(数据增强)
  • 增强泛化能力

Knowledge Distillation(知识蒸馏)_第2张图片
青出于蓝而胜于蓝,训练一个更深的student
student虽然深,但不宽,参数仍然小。但是直接训练一个比教师机还深的网络往往会很困难,2015年的《FITNETS:Hints for Thin Deep Nets》通过在中间层加入loss的方法,通过学习teacher中间层feature map来transfer中间层表达的知识,文章中把这个方法叫做Hint-based Training。这样做,学生网络不仅仅拟合教师网络的soft-target,而且会拟合隐藏层的输出(教师抽取的特征)。学习网络中的特征。
具体训练方法如上图,是选取teacher的中间层作为guidance,对student的中间层进行监督学习。由于两者的维度不一样,所以需要一个额外的线性矩阵或卷积层去进行维度变换,达到维度一致,然后使用L2 Loss进行监督。

Knowledge Distillation(知识蒸馏)_第3张图片
授人以鱼不如授人以渔
CVPR2017,《A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning》,寻找网络层之间的关系。
不再利用soft targets或者利用中间特征做hint,而是直接学习每层的特征。如上图在一些组合层间做L2 loss,其中G是低层和高层特征图的channel两两做内积最后得到的矩阵。

Knowledge Distillation(知识蒸馏)_第4张图片
teacher网络是固定的,teacher也需要与时共进
上图出自CVPR2018的《Deep Mutual Learning》,作者认为以往的teacher网络都是固定的,只用来输出soft-target,难以学习student网络中反馈的信息,所以提出深度互学习,用多个学生网络同时训练,通过真值和多个网络的输出结果“相互借鉴,共同进步”(不以模型压缩为主要目的,更多为了提升模型表现)
L 1 = L C 1 + D K L ( p 2 ∣ ∣ p 1 ) L_1=L_{C_1}+D_{KL}(p_2||p_1) L1=LC1+DKL(p2p1) L 2 = L C 2 + D K L ( p 1 ∣ ∣ p 2 ) L2=L_{C_2}+D_{KL}(p_1||p_2) L2=LC2+DKL(p1p2)
思路比较简单, L C 1 L_{C_1} LC1是经典的交叉熵, D K L D_{KL} DKL是KL散度。

似乎GAN也很合适
《KDGAN: Knowledge Distillation with Generative Adversarial Networks》
用对抗生成的方式模拟蒸馏的过程:生成器(学生网络,参数少、简单)负责基于输入X输出X的标签Y,判别器(教师网络,参数多、复杂)判断标签来自于学生网络还是真实的数据

你可能感兴趣的:(深度学习)