DistilBERT 论文笔记

单位:HuggingFace
时间:2020.5
发表:NIPS2019
论文链接:https://arxiv.org/pdf/1910.01108.pdf

一、背景

1. 什么是distill(蒸馏)?

蒸馏简单的说是将大模型(teacher)的学习结果,作为小模型(student)的学习目标,意在小模型能学习到大模型的表示。

蒸馏这个方法的核心思想是:好模型的目标不是拟合训练数据而是学习如何泛化到新的数据

所以蒸馏的目标是让学生模型学习到教师模型的泛化能力,理论上得到的结果会比单纯拟合训练数据的学生模型要好。

2. BERT有哪些短板?

从应用落地的角度来说,bert虽然效果好,但有一个短板就是预训练模型太大,预测时间在平均在300ms以上(一条数据),无法满足线上并发量要求高的业务需求。

二、DistilBERT, a distilled version of BERT

1. 作者的思路

之前的模型蒸馏本质上都是两个loss,即distillation loss和student loss

这样模型学到的都是精调后的知识,即模型都是任务相关的,作者想蒸馏出一个任务无关的BERT,这样通用性更强,在具体任务时做具体的精调即可。

2. 具体做法

I. 模型结构

教师模型采用预训练好的BERT-base,学生模型则是6层的transformer。

II. 学生模型初始化方法

采用了BERT-PKD提出的PKD-skip的方式进行初始化,即用BERT-base的第[2,4,6,8,10]层的参数作为学生模型的参数。

DistilBERT 论文笔记_第1张图片

III. Loss的设计

损失函数最终有三个,具体为:

  • MLM loss

    # lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
    # (batch_size*seq_length, vocab_size), (batch_size*seq_length, )
    loss_mlm = self.lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), lm_labels.view(-1))
    # alpha_mlm (default 0.5, here 2.0): Linear weight for the MLM loss
    loss += self.alpha_mlm * loss_mlm
    
    • 约等于bert 的masked language model的损失函数,对应为以前具体任务蒸馏的student loss

    • 用 student 的 logits 和 lm_labels 计算 Lmlm(交叉熵),并计算累计 loss:loss += alpha_mlm × Lmlm

  • CE loss

    # temperature = 2.0,推理时为 1.0
    # ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
    # 定义的损失函数:ti * log(si),计算 KL 散度(衡量两个分步的差异)
    loss_ce = self.ce_loss_fct(
        F.log_softmax(s_logits_slct / self.temperature, dim=-1),
        F.softmax(t_logits_slct / self.temperature, dim=-1)) * (self.temperature) ** 2
    # alpha_ce (default 0.5, here 5.0): Linear weight for the distillation loss.
    loss = self.alpha_ce * loss_ce
    
    • 和教师-学生最后一层的KL散度,对应为以前具体任务蒸馏的distillation loss

    • 计算 mask 的 logits,mask 可以选择只计算 masked tokens,也可以选择计算不含 padding 的 input tokens,两者最后用来计算 loss 的 logits 不相同,其中前者的 size 是 (n_tgt, vocab_size),后者的 size 是 (sum(lenghts), vocab_size)。计算 Lce(散度),并计算 loss:loss = alpha_ce × Lce

  • Cos loss

    # cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")
    loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
    # alpha_cos (default 0.0, here 1.0): Linear weight of the cosine embedding loss.
    loss += self.alpha_cos * loss_cos
    
    • 计算teacher hidden state和student hidden state的余弦相似度,它将倾向于对齐学生和教师隐藏状态向量的方向。
    • 计算 mask 的 hidden_states(最后一层),mask 选择不含 padding 的 input tokens(同上面第二种,也是输入的 attention mask),size 为 (sum(lenghts), hidden_dim)。计算 Lcos(余弦嵌入),并计算累计 loss:loss += alpha_cos × Lcos

IV. 一些小trick

  1. 采用了RoBERTa的优化策略,动态mask,增大batch size,取消NSP任务的损失函数
  2. mask 时采用了 token_probs,让选择 mask 时更加关注低频词,进而实现对 mask 的平滑取样(如果按平均分布取样的话,取到的 mask 可能大部分都是重复的高频词)

3. 实验

I. 对比实验

三张图分别是:

  • 模型在GLUE上的效果
  • 在下游任务的表现
  • 参数大小和推理速度的比较

DistilBERT 论文笔记_第2张图片

可以看出对比bert-base,精度只下降了1-2个点,推理速度和参数量大小有成倍的提升。不过值得注意的是下游任务同样用到了蒸馏方法。

II. 消融实验

从消融实验可以看出,MLM loss对于学生模型的表现影响较小,同时初始化也是影响效果的重要因素

DistilBERT 论文笔记_第3张图片

4. 总结

DistilBERT利用知识蒸馏的技术,达到不过分降低性能的情况下,减少的模型大小和加快推理的速度

三、自己的思考

  1. 模型Loss融合时没有简单的直接相加,而是加了几个超参数线性相加,设计模型Loss时值得学习。
  2. 可以把bert里所有参数都加上Loss学习起来,应该会有效果,不过操作可能比较复杂,不太美观。
  3. 在hard labels中只有单纯的0和1标签,这实际是不太科学的,每个分类之间也有他们的相似程度,这也是temperature超参数的设计思想和模型蒸馏能work的主要原因。

你可能感兴趣的:(深度学习,人工智能,自然语言处理)