知识蒸馏主要处理的是模型的有效性和效率之间的平衡问题:
模型越来越深、越来越复杂,导致模型上线后相应速度太慢,无法满足系统的低延迟要求。
知识蒸馏就是目前一种比较流行的解决此类问题的技术方向。
一般为teacher-student模式,主要思想是用一个复杂的、较大的teacher model去指导简单的、较小的student model的学习。
线上使用的是student小模型。
论文地址:https://arxiv.org/pdf/1503.02531.pdf
Knowledge distillation最早来自于hinton 2015年的一篇论文,在文中hinton提到:可以将一个大的、复杂的或者ensemble的模型获得知识transfer压缩到一个单个小模型中。
其主要思想是将训练好的teacher model输出的class probability作为soft target,让student model去学习:
文中引入了 soft target 这一概念:
q i = e x p ( z i / T ) ∑ j e x p ( z j / T ) q_i=\frac{exp(z_i/T)}{\sum_jexp(z_j/T)} qi=∑jexp(zj/T)exp(zi/T)
那么为什么要让student去学习这个soft target呢?
因为 soft target 会包含更多的信息:
下图是文中比较的soft target和hard-target的实验结果:
第一行是使用全部数据+hart target训练的baseline,第二行为使用3%数据+hart target的训练和测试accuracy,第三行为使用3%数据+soft target的训练和测试accuracy。
可以发现,仅使用3%的数据+soft target就可以达到和baseline相当的表现。并且,文中提到第二种方式很容易过拟合,必须使用early-stopping,第三种方式不需要使用ealy-stopping,可见 soft target有regularizer的作用。
模型的训练和测试
训练分为两个阶段( H H H为交叉熵):
Rocket Launching: A Universal and Efficient Framework for Training Well-performing Light Net
阿里的rocket launching在KD的基础上做了一些改进,模型结构如下:
训练的loss为:
L ( x ; W S , W L , W B ) = H ( y , p ( x ) ) + H ( y , q ( x ) ) + λ L o s s h i n t L(x;W_S,W_L,W_B) = H(y, p(x)) + H(y, q(x))+\lambda Loss_{hint} L(x;WS,WL,WB)=H(y,p(x))+H(y,q(x))+λLosshint
其中, L o s s h i n t Loss_{hint} Losshint可以有如下几种形式:
Hint loss:
MSE of final softmax: L M S E ( x ) = ∣ ∣ p ( x ) − q ( x ) ∣ ∣ 2 2 L_{MSE}(x)=||p(x)-q(x)||_2^2 LMSE(x)=∣∣p(x)−q(x)∣∣22
MSE of logits before softmax activation: L M I M I C ( x ) = ∣ ∣ l ( x ) − z ( x ) ∣ ∣ 2 2 L_{MIMIC}(x)=||l(x)-z(x)||_2^2 LMIMIC(x)=∣∣l(x)−z(x)∣∣22
knowledge distillation: L K D ( x ) = H ( s o f t m a x ( l ( x ) T ) , s o f t m a x ( z ( x ) / T ) ) L_{KD}(x)=H(softmax(\frac{l(x)}{T}), softmax(\frac{z(x)}{/T})) LKD(x)=H(softmax(Tl(x)),softmax(/Tz(x)))
Rocket launching主要有如下几处改进:
自己用rocket launching结构,测试的cvr预测任务的测试auc如下(hint loss lambda选的 T 2 T^2 T2):增大T,确实能够使得student的测试表现有所提升。
论文地址:https://arxiv.org/pdf/1412.6550.pdf
核心思想是利用teacher model中间层输出指导student model中间层的输出,获得一个thin 但是 deeper 的student model,因为一般深层的神经网络表达力更强,可以获得更加抽象的特征表征。
All previous work focuses on compressing a teacher network or an ensemble of networks into either networks of similar width and depth or into shallower and wider ones; not taking advantage of depth
allow the training of a student that is deeper and thinner than the teacher, using not only the outputs but also the intermediate representations learned by the teacher as hints to improve the training process and final performance of the student.
Hint layer和guided layer的定义:
A hint is defined as the output of a teacher’s hidden layer responsible for guiding the student’s learning process.
Analogously, we choose a hidden layer of the FitNet, the guided layer, to learn from the teacher’s hint layer.
We want the guided layer to be able to predict the output of the hint layer.
注意在student训练过程汇中引入hint是一种正则化的方式,hint/guided layer选层数越深,student网络训练的灵活性就越少,就会容易over-regularization。文中的hint/guided layer选的都是网络的中间层。
该部分详细可见张俊林博士的文章: https://zhuanlan.zhihu.com/p/143155437
精排、粗排以及模型召回环节都可以采用知识蒸馏技术来优化现有推荐系统的性能和效果。
对于精排,知识蒸馏适用于如下两种技术转换场景:
知识蒸馏应用在召回/粗排环节是比较“合算”的,因为这两个环节本身,并不追求最高的推荐精度,而模型小,速度快则是模型召回及粗排的重要目标之一,这与知识蒸馏的特点正好相符合。
召回/粗排环节的知识蒸馏可以使用两阶段的方式训练,因为teacher model可以直接用精排环节已经训练好的模型。
此外,student model可以去学习模型的logits,也可以去学习精排模型的排序偏好。