本文介绍下TinyBERT,华为在2020发布的一篇论文,主要内容是对模型进行蒸馏,蒸馏的方法值得学习
论文地址:
https://arxiv.org/abs/1909.10351
代码地址:
https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT
General Distillation
和 Task-specific Distillation
,也就是在大规模语料上对通识知识的蒸馏,这是在预训练阶段的蒸馏,和在指定任务数据上对特定任务的知识蒸馏,并且用通识知识的蒸馏模型对指定任务的蒸馏模型进行初始化,这是在微调阶段的蒸馏。同时,在特征任务上进行知识蒸馏时,会先对数据进行增强。
作者的实验结果是,4层的tinybert可以达到bertbase的96.8%的效果,但是参数量为bertbase的13.3%,推理时间为10.6%,并且比其他蒸馏的效果要好,同时,6层的tinybert和bertbase的表现近似。
1、transformer distillation
对transformer网络层数的蒸馏。假设学生模型有M层,老师模型有N层,自定义一个map函数 n = g ( m ) n=g(m) n=g(m),实现学生层到老师层的map,表示学生模型的第m层从老师模型的第g(m)层学得信息。损失函数如下:
L m o d e l = ∑ x ∈ X ∑ m = 0 M + 1 λ m L l a y e r ( f m S ( x ) , f g ( m ) T ( x ) ) L_{model}=\sum_{x\in_{X}}\sum_{m=0}^{M+1}\lambda_mL_{layer}(f_m^S(x),f_{g(m)}^T(x)) Lmodel=∑x∈X∑m=0M+1λmLlayer(fmS(x),fg(m)T(x))
L l a y e r L_{layer} Llayer表示的是某一个transformer layer或者是embedding layer的损失函数, f m ( x ) f_m(x) fm(x)表示第m层的目标函数值, λ m \lambda_m λm表示第m层的重要性,为超参数。
transformer distillation包括attention distill、hidden distill、embedding distill、以及prediction distill,如下图所示:
attention
其中,attention distill的目标函数为:
L a t t n = 1 h ∑ i = 1 h M S E ( A i S , A i T ) L_{attn}=\frac{1}{h}\sum_{i=1}^hMSE(A_i^S,A_i^T) Lattn=h1∑i=1hMSE(AiS,AiT)
h表示注意力头的个数, A i A_i Ai表示学生或老师第i个注意力头的attention matrix
同时,作者表明,之所以使用 A i A_i Ai,而不是 s o f t m a x ( A i ) softmax(A_i) softmax(Ai)作为拟合目标,是因为前者的收敛更快,效果更好。
hidden
其中,transformer输出蒸馏的损失函数为:
L h i d n = M S E ( H S W h , H T ) L_{hidn}=MSE(H^SW_h,H^T) Lhidn=MSE(HSWh,HT)
其中 H S ∈ R l × d ∗ H^S\in{R^{l \times d^*}} HS∈Rl×d∗, H T ∈ R l × d H^T\in{R^{l \times d}} HT∈Rl×d
, d ∗ d^* d∗表示学生模型的向量维度。 W h W_h Wh是一个可学习矩阵,用来对学生模型进行线性变化,将其转化为与老师模型相同的维度。
embedding
embedding层输出蒸馏的损失函数为:
L e m b d = M S E ( E S W e , E T ) L_{embd}=MSE(E^SW_e,E^T) Lembd=MSE(ESWe,ET)
可以看到基本与transformer输出的蒸馏形式是一样的。
prediction
L p r e d = C E ( z T t , z S t ) L_{pred}=CE(\frac{z^T}{t},\frac{z^S}{t}) Lpred=CE(tzT,tzS)
z表示logits,t表示温度系数,作者实验发现,t=1时效果最好。这部分的损失函数就和distillbert设计的蒸馏损失比较像.
整体模型的损失函数如下:
L l a y e r = { L e m b d m = 0 L h i d n + L a t t n m ∈ ( 0 , M ] L p r e d m = M + 1 L_{layer}=\begin{cases} L_{embd} & m=0 \\ L_{hidn}+L_{attn} & m\in(0,M] \\ L_{pred} & m=M+1 \end{cases} Llayer=⎩⎪⎨⎪⎧LembdLhidn+LattnLpredm=0m∈(0,M]m=M+1
其中,m表示学生的层数
2、task-specific distillation
该部分先对数据集进行增强,然后进行蒸馏。作者对数据增强的解释为,学生模型在经过增强的数据集上进行训练,可以提高其效果,也就是说,相比于老师模型,学生模型在特定任务上的训练数据是经过增强的,以此来提升学生模型的效果,因此学生就有超过老师的可能。
作者结合bert和glove的词嵌入,在word-level上进行替换,以实现数据增强。作者的参数设置如下, p t = 0.4 p_t=0.4 pt=0.4, N a = 20 N_a=20 Na=20, K = 15 K=15 K=15
论文并没有对task-specific distillation的蒸馏部分进行阐述,说明其与general distill的蒸馏方式应该是一样的,只是一个处于预训练阶段,一个处于微调阶段。
3、实验结果
实验时,作者使用 g ( m ) = 3 × m g(m)=3 \times m g(m)=3×m进行映射,也就是说4层的tinybert的每层都是从3层的bertbase中学得。
下面是作者对tinybert使用得学习策略和蒸馏方式做的消融实验:
下面是作者针对学生层到老师层的映射做的消融实验:
可以看到,使用均匀映射的效果是最好的,同时,作者也表明,对于一个下游任务,自适应的选择层数是一个具有挑战性的问题,也是未来的工作方向。
每个蒸馏,又会进行以下操作:
之所以做数据增强,是为了在对具体任务蒸馏时,扩充学生模型的训练集,提高学生模型的表现。