预训练模型的提出,比如 BERT,显著的提升了很多自然语言处理任务的表现,它的强大是毫无疑问的。但是他们普遍存在参数过多、模型庞大、推理时间过长、计算昂贵等问题,因此很难落地到实际的产业应用中。TinyBERT是由华中科技大学和华为诺亚方舟实验室联合提出的一种针对transformer-based模型的知识蒸馏方法,以BERT为例对大型预训练模型进行研究。四层结构的 T i n y B E R T 4 TinyBERT_4 TinyBERT4 在 GLUE benchmark 上可以达到 B E R T b a s e BERT_{base} BERTbase 96.8%及以上的性能表现,同时模型缩小7.5倍,推理速度提升9.4倍。六层结构的 T i n y B E R T 6 TinyBERT_6 TinyBERT6 可以达到和 B E R T b a s e BERT_{base} BERTbase 同样的性能表现。
提供一种新的针对 transformer-based 模型进行蒸馏的方法,使得BERT中具有的语言知识可以迁移到TinyBERT中去。
提出一个两阶段学习框架,在预训练阶段和 fine-tuning 阶段都进行蒸馏,确保 TinyBERT 可以充分的从BERT中学习到一般领域和特定任务两部分的知识。
知识蒸馏的目的在于将一个大型的教师网络 T T T 学习到的知识迁移到小型的学生网络 S S S 中。学生网络通过训练来模仿教师网络的行为。 f S f^S fS 和 f T f^T fT 代表教师网络和学生网络的behavior functions。这个行为函数的目的是将网络的输入转化为信息性表示,并且它可被定义为网络中任何层的输出。在基于transformer的模型的蒸馏中,MHA(multi-head attention)层或FFN(fully connected feed-forward network)层的输出或一些中间表示,比如注意力矩阵 A A A 都可被作为行为函数使用。 L K D = ∑ x ∈ X L ( f S ( x ) , f T ( x ) ) L_{KD} = \sum_{x \in X}L(f^S(x), f^T(x)) LKD=x∈X∑L(fS(x),fT(x))其中 L ( ⋅ ) L(⋅) L(⋅) 是一个用于评估教师网络和学生网络之间差异的损失函数, x x x 是输入文本, X X X 代表训练数据集。因此,蒸馏的关键问题在于如何定义行为函数和损失函数。
假设TinyBert有M层transformer layer,teacher BERT有N层transformer layer,则需要从teacher BERT的N层中抽取M层用于transformer层的蒸馏。 n = g ( m ) n=g(m) n=g(m) 定义了一个从学生网络到教师网络的映射关系,表示学生网络中第m层网络信息是从教师网络的第 g ( m ) g(m) g(m) 层学习到的,也就是教师网络的第n层。TinyBERT嵌入层和预测层也是从BERT的相应层学习知识的,其中嵌入层对应的指数为0,预测层对应的指数为M + 1,对应的层映射定义为 0 = g ( 0 ) 0=g(0) 0=g(0) 和 N + 1 = g ( M + 1 ) N+1=g(M+1) N+1=g(M+1)。在形式上,学生模型可以通过最小化以下的目标函数来获取教师模型的知识: L m o d e l = ∑ x ∈ X ∑ M + 1 m = 0 λ m L l a y e r ( f m S ( x ) , f g ( m ) T ( x ) ) L_{model} = \sum_{x \in X}\sum^{M+1}{m=0}\lambda_m L{layer}(f^S_m(x), f^T_{g(m)}(x)) Lmodel=x∈X∑∑M+1m=0λmLlayer(fmS(x),fg(m)T(x))其中 L l a y e r L_{layer} Llayer 是给定的模型层的损失函数(比如transformer层或嵌入层), f m f_m fm 代表第m层引起的行为函数, λ m λ_m λm 表示第m层蒸馏的重要程度。
TinyBERT的蒸馏分为以下三个部分:transformer-layer distillation、embedding-layer distillation、prediction-layer distillation。
Transformer-layer的蒸馏由attention based蒸馏和hidden states based蒸馏两部分组成。
其中,attention based蒸馏是受到论文Clack et al., 2019的启发,这篇论文中提到,BERT学习的注意力权重可以捕获丰富的语言知识,这些语言知识包括对自然语言理解非常重要的语法和共指信息。因此,TinyBERT提出attention based蒸馏,其目的是使学生网络很好地从教师网络处学习到这些语言知识。具体到模型中,就是让TinyBERT网络学习拟合BERT网络中的多头注意力矩阵,目标函数定义如下: L a t t n = 1 h ∑ i = 1 h M S E ( A i S , A i T ) L_{attn} = \frac{1}{h}\sum^{h}_{i=1}MSE(A^S_i, A^T_i) Lattn=h1i=1∑hMSE(AiS,AiT)其中, h h h 代表注意力头数, A i ∈ R l × l A_i \in \mathbb{R}^{l\times l} Ai∈Rl×l 代表学生或教师的第 i 个注意力头对应的注意力矩阵, l l l 代表输入文本的长度。论文中提到,使用注意力矩阵 A A A 而不是 s o f t m a x ( A ) softmax(A) softmax(A) 是因为实验结果显示这样可以得到更快的收敛速度和更好的性能表现。
hidden states based蒸馏是对transformer层输出的知识进行了蒸馏处理,目标函数定义为: L h i d n = M S E ( H S W h , H T ) L_{hidn} = MSE(H^SW_h, H^T) Lhidn=MSE(HSWh,HT)其中, H S ∈ R l × d ′ , H T ∈ R l × d H^S \in \mathbb{R}^{l \times d^{'}},\quad H^T \in \mathbb{R}^{l \times d} HS∈Rl×d′,HT∈Rl×d 分别代表学生网络和教师网络的隐状态,是FFN的输出。 d d d 和 d ′ d' d′代表教师网络和学生网络的隐藏状态大小,且 d ′ < d d'
L e m b d = M S E ( E S W e , E T ) L_{embd} = MSE(E^SW_e, E^T) Lembd=MSE(ESWe,ET)Embedding loss和hidden states loss同理,其中 E S , E T E^S,E^T ES,ET 代表学生网络和教师网络的嵌入,他呢和隐藏状态矩阵的形状相同,同时 W e W_e We 和 W h W_h Wh 的作用也相同。
L p r e d = C E ( z T / t , z S / t ) L_{pred} = CE(z^T/t, z^S/t) Lpred=CE(zT/t,zS/t)其中, z S , z T z^S, \quad z^T zS,zT 分别是学生网络和教师网络预测的logits向量, C E CE CE 代表交叉熵损失, t t t 是temperature value,当 t = 1 t=1 t=1 时,表现良好。
对上述三个部分的loss函数进行整合,则可以得到教师网络和学生网络之间对应层的蒸馏损失如下: L l a y e r = { L e m b d , m = 0 L h i d n + L a t t n , M ≥ m > 0 L p r e d , m = M + 1 L_{layer}=\begin{cases} L_{embd}, & m=0 \\ L_{hidn} + L_{attn}, & M \geq m > 0 \\ L_{pred}, & m = M + 1 \end{cases} Llayer=⎩ ⎨ ⎧Lembd,Lhidn+Lattn,Lpred,m=0M≥m>0m=M+1
论文地址:https://arxiv.org/abs/1909.10351
代码地址:https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT