BERT 模型蒸馏 TinyBERT

BERT 的效果好,但是模型太大且速度慢,因此需要有一些模型压缩的方法。TinyBERT 是一种对 BERT 压缩后的模型,由华中科技和华为的研究人员提出。TinyBERT 主要用了模型蒸馏的方法进行压缩,在 GLUE 实验中可以保留 BERT-base 96% 的性能,但体积比 BERT 小了 7 倍,速度快了 9 倍。

1.前言

模型蒸馏 Distillation 是一种常用的模型压缩方法,首先训练一个大的 teacher 模型,然后使用 teacher 模型输出的预测值训练小的 student 模型。student 模型学习 teacher 模型的预测结果 (概率值) 从而学习到 teacher 模型的泛化能力。

有不少利用模型蒸馏压缩 BERT 的研究,例如将 BERT 蒸馏到 BiLSTM,还有 huggingface 的 DistilBERT 等,在之前的文章《BERT 模型蒸馏 Distillation BERT》有比较详细的介绍,对模型蒸馏不熟悉的童鞋可以参考一下。

这里主要介绍另一种 BERT 蒸馏的模型 TinyBERT,之前蒸馏模型的损失函数主要是针对 teacher 模型输出的预测概率值,而 TinyBERT 的损失函数包括四个部分:Embedding 层的损失,Transformer 层 attention 的损失,Transformer 层 hidden state 的损失和最后预测层的损失。即 student 模型不仅仅学习 teacher 模型的预测概率,也学习其 Embedding 层和 Transformer 层的特性。

TinyBERT 结构

上面的图片展示了 TinyBERT (studet) 和 BERT (teacher) 的结构,可以看到 TinyBERT 减少了 BERT 的层数,并且减小了 BERT 隐藏层的维度。

2.TinyBERT

TinyBERT 蒸馏过程中的损失函数主要包含以下四个:

  • Embedding 层损失函数
  • Transformer 层 attention 损失函数
  • Transformer 层 hidden state 损失函数
  • 预测层损失函数

我们先看一下 TinyBERT 蒸馏时候每一层的映射方法。

2.1 TinyBERT 蒸馏的映射方法

假设 TinyBERT 有 M 个 Transformer 层,而 BERT 有 N 个 Transformer 层。TinyBERT 蒸馏主要涉及的层有 embedding 层 (编号为0)、Transformer 层 (编号为1到M) 和输出层 (编号 M+1)。

我们需要将 TinyBERT 每一层和 BERT 中要学习的层对应起来,然后再蒸馏。对应的函数为 g(m) = n,m 是 TinyBERT 层的编号,n 是 BERT 层的编号。

对于 embedding 层,TinyBERT 蒸馏的时候 embedding 层 (0) 对应了 BERT 的 embedding 层 (0),即 g(0) = 0。

对于输出层,TinyBERT 的输出层 (M+1) 对应了 BERT 的输出层 (N+1),即 g(M+1) = N+1。

对于中间的 Transformer 层,TinyBERT 采用 k 层蒸馏的方法,即 g(m) = m × N / M。例如 TinyBERT 有 4 层 Transformer,BERT 有 12 层 Transformer,则 TinyBERT 第 1 层 Transformer 学习的是 BERT 的第 3 层;而TinyBERT 第 2 层学习 BERT 的第 6 层。

2.2 Embedding 层损失函数

Embedding 层损失函数

ES 是 TinyBERT 的 embedding,ET 是 BERT 的 embedding,l 是句子序列的长度,而 d‘ 是 TinyBERT embedding 维度,d 是 BERT embedding 维度。因为是要压缩 BERT 模型,所以 d' < d,TinyBERT 希望模型学到的 embedding 与 BERT 原来的 embedding 具有相似的语义,因此采用了上面的损失函数,减少两者 embedding 的差异。

embedding 维度不同,不能直接计算 loss,因此 TInyBERT 增加了一个映射矩阵 We (d'×d) 的矩阵,ES 乘以映射矩阵后维度与 ET 一样。embedding loss 就是二者的均方误差 MSE。

2.3 Transformer 层 attention 损失函数

TinyBERT 在 Transformer 层损失函数有两个,第一个是 attention loss,如下图所示。

TinyBERT attention loss

attention loss 主要是希望 TinyBERT Multi-Head Attention 部分输出的 attention score 矩阵 能够接近 BERT 的 attention score 矩阵。因为有研究发现 BERT 学习到的 attention score 矩阵能够包含语义知识,例如语法和相互关系等,具体可参考论文《What Does BERT Look At? An Analysis of BERT’s Attention》。TinyBERT 通过下面的损失函数学习 BERT attention 的功能,h 表示 Multi-Head Attention 中 head 的个数。

TinyBERT attention loss

2.4 Transformer 层 hidden state 损失函数

TinyBERT 在 Transformer 层的第二个损失函数是 hidden loss,如下图所示。

TinyBERT hidden 层蒸馏

hidden state loss 和 embedding loss 类似,计算公式如下,也需要经过一个映射矩阵。

TinyBERT hidden loss

2.5 预测层损失函数

预测层的损失函数采用了交叉熵,计算公式如下,其中 t 是模型蒸馏的 temperature value,zT 是 BERT 的预测概率,而 zS 是 TinyBERT 的预测概率。

预测层的损失函数

3.TinyBERT 两阶段训练方法

TinyBERT 两阶段训练方法

BERT 有两个训练阶段,第一个训练阶段是训练一个预训练模型 (预训练 BERT),第二个训练阶段是针对具体下游任务微调 (微调后的 BERT)。TinyBERT 的蒸馏过程也分为两个阶段。

  • 第一个阶段,TinyBERT 在大规模的 general 数据集上,利用预训练 BERT 蒸馏出一个 General TinyBERT。
  • 第二个阶段,TinyBERT 采用数据增强,利用微调后的 BERT 训练一个 task specific 模型。

4.参考文献

TINYBERT: DISTILLING BERT FOR NATURAL LANGUAGE UNDERSTANDING

你可能感兴趣的:(BERT 模型蒸馏 TinyBERT)