NLP-预训练模型-2019-NLU:DistilBERT【 BERT模型压缩】【模型大小减小了40%(66M),推断速度提升了60%,但性能只降低了约3%】

《原始论文:DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter》

NLP预训练模型随着近几年的发展,参数量越来越大,受限于算力,在实际落地上线带来了困难,针对最近最为流行的BERT预训练模型,提出了DistilBert,在保留97%的性能的前提下,模型大小下降40%,inference运算速度快了60%。

NLP-预训练模型-2019-NLU:DistilBERT【 BERT模型压缩】【模型大小减小了40%(66M),推断速度提升了60%,但性能只降低了约3%】_第1张图片
Distill的意思是蒸馏,我们可以从字面上猜测,我们要从一个很大的模型,蒸馏成比较小的模型,也可以用一种角度想,我们让大的模型当作小的模型的老师,而小模型这个学生,只会尽可能的学老师的每个输出。最早提出于Hinton大佬的论文。

一、模型蒸馏原理

Hinton在NIPS2014《Distilling the Knowledge in a Neural Network》 提出了知识蒸馏(Knowledge Distillation)的概念,旨在把一个大模型或者多个模型ensemble学到的知识迁移到另一个轻量级单模型上,方便部署。

简单的说就是用小模型去学习大模型的预测结果,而不是直接学习训练集中的label。

在蒸馏的过程中,我们将原始大模型称为教师模型(teacher),新的小模型称为学生模型(student),训练集中的标签称为hard label,教师模型预测的概率输出为soft label,temperature(T)是用来调整soft label的超参数。

蒸馏这个概念之所以work,核心思想是因为好模型的目标不是拟合训练数据,而是学习如何泛化到新的数据。所以蒸馏的目标是让学生模型学习到教师模型的泛化能力,理论上得到的结果会比单纯拟合训练数据的学生模型要好。

1、如何蒸馏

蒸馏发展到今天,有各种各样的花式方法,我们先从最基本的说起。

之前提到学生模型需要通过教师模型的输出学习泛化能力,那对于简单的二分类任务来说,直接拿教师预测的0/1结果会与训练集差不多,没什么意义,那拿概率值是不是好一些?于是Hinton采用了教师模型的输出概率q,同时为了更好地控制输出概率的平滑程度,给教师模型的softmax中加了一个参数T。

有了教师模型的输出后,学生模型的目标就是尽可能拟合教师模型的输出,新loss就变成了:

其中CE是交叉熵(Cross-Entropy),y是真实label,p是学生模型的预测结果, 是蒸馏loss的权重。这里要注意的是,因为学生模型要拟合教师模型的分布,所以在求p时的也要使用一样的参数T。

另外,因为在求梯度时新的目标函数会导致梯度是以前的 1 T 2 \frac{1}{T^2} T21 ,所以要再乘上 T 2 T^2 T2,不然 T T T 变了的话hard label不减小(T=1),但soft label会变。

有同学可能会疑惑:如果可以拟合prob,那直接拟合logits可以吗?

当然可以,Hinton在论文中进行了证明,如果T很大,且logits分布的均值为0时,优化概率交叉熵和logits的平方差是等价的。

2、BERT蒸馏

在BERT提出后,如何瘦身就成了一个重要分支。主流的方法主要有剪枝、蒸馏和量化。量化的提升有限,因此免不了采用剪枝+蒸馏的融合方法来获取更好的效果。

接下来将介绍BERT蒸馏的主要发展脉络,从各个研究看来,蒸馏的提升:

  • 一方面来源于从:精调阶段蒸馏->预训练阶段蒸馏
  • 另一方面则来源于蒸馏最后一层知识->蒸馏隐层知识->蒸馏注意力矩阵

在监督学习领域,对于一个分类问题,定义soft label为模型的输出(即不同label的概率), hard label为最终正确的label(也就是ground truth),通常是通过最大化正确label的概率来进行学习的,通常采用 cross-entropy作为损失函数,即让正确label的概率尽可能预测为1,其余label的概率趋近于0,但是这些不正确趋近于0的label也是有大有小的(比把图片数字2识别成3的概率还是要比识别成9大,尽管他们都趋近于0),这被称为"暗知识(Dark Knowledge)", 这也反应了模型的泛化能力。但因为过于趋近0不利于student模型学习,为了让student也容易学习tearcher的输出,引入了带温度T的softmax概率为
NLP-预训练模型-2019-NLU:DistilBERT【 BERT模型压缩】【模型大小减小了40%(66M),推断速度提升了60%,但性能只降低了约3%】_第2张图片
当温度T为1的时候,即为标准的softmax。训练的时候T>1, 方便学到类间信息;预测的时候T=1,恢复到标准的softmax进行计算。T越大,输出的概率约平滑。

具体模型的训练方式如图
NLP-预训练模型-2019-NLU:DistilBERT【 BERT模型压缩】【模型大小减小了40%(66M),推断速度提升了60%,但性能只降低了约3%】_第3张图片
Loss Fn为cross entropy,最终的损失函数为图中两个loss的线性组合。

二、DistillBert

DistillBERT的教师模型采用了预训练好的BERT-base,学生模型则是6层transformer,采用了PKD-skip的方式进行初始化。和之前蒸馏目标不同的是,为了调整教师和学生的隐层向量方向,作者新增了一个cosine embedding loss,蒸馏最后一层hidden的。最终损失函数由MLM loss、教师-学生最后一层的交叉熵、隐层之间的cosine loss组成。

DistillBert是在bert的基础上用知识蒸馏技术训练出来的小型化bert。整体上来说这篇论文还是非常简单的,只是引入了知识蒸馏技术来训练一个小的bert。具体做法如下:

  1. 给定原始的bert-base作为teacher网络。
  2. 在bert-base的基础上将网络层数减半(也就是从原来的12层减少到6层)。
  3. 利用teacher的软标签和teacher的隐层参数来训练student网络。

训练时的损失函数定义为三种损失函数的线性和,三种损失函数分别为:

  1. L c e L_{ce} Lce:这是teacher网络softmax层输出的概率分布和student网络softmax层输出的概率分布的交叉熵(注:MLM任务的输出)。
  2. L m l m L_{mlm} Lmlm。这是student网络softmax层输出的概率分布和真实的one-hot标签的交叉熵
  3. L c o s L_{cos} Lcos。这是student网络隐层输出和teacher网络隐层输出的余弦相似度值,在上面我们说student的网络层数只有6层,teacher网络的层数有12层,因此个人认为这里在计算该损失的时候是用student的第1层对应teacher的第2层,student的第2层对应teacher的第4层,以此类推。

作者对student的初始化也做了些工作,作者用teacher的参数来初始化student的网络参数,做法和上面类似,用teacher的第2层初始化student的第1层,teacher的第4层初始化student的第2层。

作者也解释了为什么减小网络的层数,而不减小隐层大小,作者认为在现代线性代数框架中,在张量计算中,降低最后一维(也就是隐层大小)的维度对计算效率提升不大,反倒是减小层数,也提升计算效率。

另外作者在这里移除了句子向量和pooler层,在这里也没有看到NSP任务的损失函数,因此个人认为作者也去除了NSP任务(主要是很多人证明该任务并没有什么效果)。

整体上来说虽然方法简单,但是效果还是很不错的,模型大小减小了40%(66M),推断速度提升了60%,但性能只降低了约3%。

三、BERT蒸馏技巧

1、剪层还是减维度?

这个选择取决于是预训练蒸馏还是精调蒸馏。

预训练蒸馏的数据比较充分,可以参考MiniLM、MobileBERT或者TinyBERT那样进行剪层+维度缩减,如果想蒸馏中间层,又不想像MobileBERT一样增加bottleneck机制重新训练一个教师模型的话可以参考TinyBERT,在计算隐层loss时增加一个线性变换,扩大学生模型的维度:

对于针对某项任务、只想蒸馏精调后BERT的情况,则推荐进行剪层,同时利用教师模型的层对学生模型进行初始化。从BERT-PKD以及DistillBERT的结论来看,采用skip(每隔n层选一层)的初始化策略会优于只选前k层或后k层。

2、用哪个Loss?

看完原理后相信大家也发现了,基本上每个模型蒸馏都用的是不同的损失函数,CE、KL、MSE、Cos魔幻组合,自己蒸馏时都不知道选哪个好。梳理了一番,大家可以根据自己的任务目标挑选:
NLP-预训练模型-2019-NLU:DistilBERT【 BERT模型压缩】【模型大小减小了40%(66M),推断速度提升了60%,但性能只降低了约3%】_第4张图片
对于hard label,使用KL和CE是一样的,因为,训练集不变时label分布是一定的。但对于soft label则不同了,不过表中不少模型还是采用了CE,只有Distilled BiLSTM发现 M S E ( z t , z s ) MSE(z^t,z^s) MSE(zt,zs) 更好。个人认为可以CE/MSE/KL都试一下,但MSE有个好处是可以避免T的调参。

中间层输出的蒸馏,大多数模型都采用了MSE,只有DistillBERT加入了cosine loss来对齐方向。

注意力矩阵的蒸馏loss则比较统一,如果要蒸馏softmax之前的attention logits可以采用MSE,之后的attention prob可以用KL散度。

3、T 和 α 如何设置?

超参数 α 主要控制soft label和hard label的loss比例,Distilled BiLSTM在实验中发现只使用soft label会得到最好的效果。个人建议让soft label占比更多一些,一方面是强迫学生更多的教师知识,另一方面实验证实soft target可以起到正则化的作用,让学生模型更稳定地收敛。

超参数T主要控制预测分布的平滑程度,TinyBERT实验发现T=1更好,BERT-PKD的搜索空间则是{5, 10, 20}。因此建议在1~20之间多尝试几次,T越大越能学到teacher模型的泛化信息。比如MNIST在对2的手写图片分类时,可能给2分配0.9的置信度,3是1e-6,7是1e-9,从这个分布可以看出2和3有一定的相似度,这种时候可以调大T,让概率分布更平滑,展示teacher更多的泛化能力。

4、需要逐层蒸馏吗?

如果不是特别追求零点几个点的提升,建议无脑一次性蒸馏,从MobileBERT来看这个操作性价比太低了。

四、蒸馏代码实战

目前Pytorch版本的模型蒸馏有一个非常赞的开源工具 TextBrewer,在它的src/textbrewer/losses.py文件下可以看到各种loss的实现。

最后输出层的CE/KL/MSE loss比较简单,只需要将两者的logits除temperature之后正常计算就可以了,以CE为例:

最后输出层的CE/KL/MSE loss比较简单,只需要将两者的logits除temperature之后正常计算就可以了,以CE为例:

def kd_ce_loss(logits_S, logits_T, temperature=1):
    '''
    Calculate the cross entropy between logits_S and logits_T
    :param logits_S: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
    :param logits_T: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
    :param temperature: A float or a tensor of shape (batch_size, length) or (batch_size,)
    '''
    if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
        temperature = temperature.unsqueeze(-1)
    beta_logits_T = logits_T / temperature
    beta_logits_S = logits_S / temperature
    p_T = F.softmax(beta_logits_T, dim=-1)
    loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1)).sum(dim=-1).mean()
    return loss

对于hidden MSE的蒸馏loss,则需要去除被mask的部分,另外如果维度不一致,需要额外加一个线性变换,TextBrewer默认输入维度是一致的:

def hid_mse_loss(state_S, state_T, mask=None):
    '''
    * Calculates the mse loss between `state_S` and `state_T`, which are the hidden state of the models.
    * If the `inputs_mask` is given, masks the positions where ``input_mask==0``.
    * If the hidden sizes of student and teacher are different, 'proj' option is required in `inetermediate_matches` to match the dimensions.
    :param torch.Tensor state_S: tensor of shape  (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor state_T: tensor of shape  (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor mask:    tensor of shape  (*batch_size*, *length*)
    '''
    if mask is None:
        loss = F.mse_loss(state_S, state_T)
    else:
        mask = mask.to(state_S)
        valid_count = mask.sum() * state_S.size(-1)
        loss = (F.mse_loss(state_S, state_T, reduction='none') * mask.unsqueeze(-1)).sum() / valid_count
    return loss

蒸馏attention矩阵则也要考虑mask,但注意这里要处理的维度是N*N:

def att_mse_loss(attention_S, attention_T, mask=None):
    '''
    * Calculates the mse loss between `attention_S` and `attention_T`.
    * If the `inputs_mask` is given, masks the positions where ``input_mask==0``.
    :param torch.Tensor logits_S: tensor of shape  (*batch_size*, *num_heads*, *length*, *length*)
    :param torch.Tensor logits_T: tensor of shape  (*batch_size*, *num_heads*, *length*, *length*)
    :param torch.Tensor mask: tensor of shape  (*batch_size*, *length*)
    '''
    if mask is None:
        attention_S_select = torch.where(attention_S <= -1e-3, torch.zeros_like(attention_S), attention_S)
        attention_T_select = torch.where(attention_T <= -1e-3, torch.zeros_like(attention_T), attention_T)
        loss = F.mse_loss(attention_S_select, attention_T_select)
    else:
        mask = mask.to(attention_S).unsqueeze(1).expand(-1, attention_S.size(1), -1) # (bs, num_of_heads, len)
        valid_count = torch.pow(mask.sum(dim=2),2).sum()
        loss = (F.mse_loss(attention_S, attention_T, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(2)).sum() / valid_count
    return loss

最后是只在DistillBERT中出现的cosine loss,可以直接使用pytorch的默认接口:

def cos_loss(state_S, state_T, mask=None):
    '''
    * Computes the cosine similarity loss between the inputs. This is the loss used in DistilBERT, see `DistilBERT `_
    * If the `inputs_mask` is given, masks the positions where ``input_mask==0``.
    * If the hidden sizes of student and teacher are different, 'proj' option is required in `inetermediate_matches` to match the dimensions.
    :param torch.Tensor state_S: tensor of shape  (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor state_T: tensor of shape  (*batch_size*, *length*, *hidden_size*)
    :param torch.Tensor mask:    tensor of shape  (*batch_size*, *length*)
    '''
    if mask is  None:
        state_S = state_S.view(-1,state_S.size(-1))
        state_T = state_T.view(-1,state_T.size(-1))
    else:
        mask = mask.to(state_S).unsqueeze(-1).expand_as(state_S).to(mask_dtype) #(bs,len,dim)
        state_S = torch.masked_select(state_S, mask).view(-1, mask.size(-1))  #(bs * select, dim)
        state_T = torch.masked_select(state_T, mask).view(-1, mask.size(-1))  # (bs * select, dim)
 
    target = state_S.new(state_S.size(0)).fill_(1)
    loss = F.cosine_embedding_loss(state_S, state_T, target, reduction='mean')
    return loss



参考资料:
All The Ways You Can Compress BERT
NLP中的预训练语言模型(四)—— 小型化bert(DistillBert, ALBERT, TINYBERT)
BERT模型压缩-DistilBERT(第一篇)
distilbert模型的测试
DistilBert解读
NLP中的预训练语言模型(四)—— 小型化bert(DistillBert, ALBERT, TINYBERT)
【BERT蒸馏】DistilBERT、Distil-LSTM、TinyBERT、FastBERT(论文+代码)
BERT蒸馏完全指南|原理/技巧/代码
模型压缩实践收尾篇——模型蒸馏以及其他一些技巧实践小结

你可能感兴趣的:(#,Bert系列,自然语言处理,bert,深度学习,DistilBert)