TinyBert解读

1. 背景

随着NLP语言模型的发展,模型参数越来越多,计算流程越来越复杂,很难将BERT应用到一些资源有限的设备上。为了提升模型的计算速度,提出了tinybert,模型预测提升了9.4倍的速度,模型大小缩小了7.5倍. 具体论文见 《TinyBERT: Distilling BERT for Natural Language Understanding》

1.1 模型结构变换

tinybert的层数相对bert base从12层降低到4层;FFN层输出的大小从3072降低到1200,Head个数12保持不变,hiddent size从768降低至312;最终参数量从110M降低到14.5M

2. Transformer distillation

论文为transformer设计了一套专门的蒸馏方式。

首先,对于n层的teacher bert,设计了一个mapping function n = g ( m ) n = g(m) n=g(m)j将student bert的第m层映射为原来的teacher的第n层(论文中teacher 为12层,student为4层,因此student第1层对应原来的teacher的第3层,以此类推),特别的当m=0时,对应bert输入的embedding layer。

2-1. Attention based distillation

受到attention weights携带了大量语言和语法信息的启发,因此学生网络会学习teacher transformer layer中的multi-head attention.
TinyBert解读_第1张图片
这里h为multi-head attention的head的个数,计算student attention矩阵和teacher attention矩阵权重的MSE loss 进行蒸馏学习

2-2. Hidden states based distillation

TinyBert解读_第2张图片
这里是对teacher bert的层间蒸馏,因为bert base的hidden states的维度为768,student 的hidden states的维度为128,因此引入 W h W_h Wh进行不同维度的映射(通过学习得到映射矩阵),对最终的transform的输出结果的MSE loss进行蒸馏学习

2-3. Embedding-layer distillation

在这里插入图片描述
对bert的输入embedding矩阵进行蒸馏学习,原理同以上

2-4. Prediction-layer distillation

仿照Hinton老爷子的蒸馏方式,除了前面的层间蒸馏,对最终输出层的结果的softmax带温度进行蒸馏学习
TinyBert解读_第3张图片
综上,对bert的transformer的最终蒸馏为一下,对于输入层m=0,只用embeding的MSE蒸馏loss;对0-M中间层的蒸馏,采用attention的 MSE loss和hidden states的MSE loss;对最终输出层,采用带温度的softmax 的交叉熵loss
TinyBert解读_第4张图片

3. TinyBERT Learning

作者对bert的蒸馏总结为以下两步
TinyBert解读_第5张图片

3.1 general distillation

在原始teacher bert(没有微调)上用通用预料蒸馏tinybert(方式采用上面的transformer distillation)

3.3 task-specific distillation

对teacher bert进行fine-tune后作为新的teacher,用任务相关的数据对上一步得到的tinybert进行蒸馏得到student tinybert(即以上一步3.1中的tinybert的参数初始化进行蒸馏);方式也是采用transformer distillation。

这里作者在task-specific的预料中还做了data-augmentation,具体做法如下:
TinyBert解读_第6张图片
数据增强的方式:即对bert分词结果,如果是一个完整的词,则加上mask,对teacher bert在mask对于节点的输出上,取topK个最大概率可能的词进入候选集合;否则用GloVe的词向量找topK个cosine相似度最高的词进入候选集合。最后随机候选集合中的词,以阈值p为概率决定是否替换原词。作者通过这种方式增加新的训练数据。

总结

作者引入了一种针对bert transformer专门的蒸馏方式,并且提供了一种有效分为两步的bert蒸馏框架,达到了当时SOTA的效果。

你可能感兴趣的:(NLP,深度学习,人工智能,nlp,机器学习,自然语言处理)