模型蒸馏,Bert模型蒸馏之distilbert

因为现在用的模型越来越大,所以出现了模型蒸馏,模型蒸馏出现的意义在于,用更少的参数,继承模型里面的效果,现在用的模型蒸馏常用teacher-student模型的方式进行蒸馏,整个理念就是让teacher模型学习到模型的大参数,让student模型继承它。

Distill的意思是蒸馏,我们可以从字面上猜测,我们要从一个很大的模型,蒸馏成比较小的模型,也可以用一种角度想,我们让大的模型当作小的模型的老师,而小模型这个学生,只会尽可能的学老师的每个输出。

Bert是12层transformer encode,Distilled BERT是6层transformer encode,Distilled BERT是没有进行自己的预训练,而是将bert的部分参数直接加载到Distilled BERT结构中作为初始化。

很多先前的工作都研究了使用蒸馏来构建特定于任务的模型,但我们在预训练阶段利用了知识蒸馏,并表明有可能将ERT模型的大小减少40%,同时保留其97%的语言理解能力,并将速度提高60%。为了利用较大模型在预训练期间学习的归纳偏差,我们引入了结合语言建模、蒸馏和余弦距离损失的三重损失。我们的模型更小、更快、更轻,预训练成本更低,我们在概念验证实验和比较设备上研究中展示了它在设备上计算的能力。

大致上Distilled BERT的思想就是这样简单,根据作者的实验数据,DistilBERT的参数大约只有BERT的40%,而速度快了60%,并且保持了一定精度。

注意那个温度T的概念,其实就是除,他会把所有logit都除于T,可以想象T越大,大家都除了他之后,大家的“差别”就越小,这样一来,模型就不止能学到对的那个标签,其他错误的标签也学到,一样的道理,T如果是1,就是之前的普通情况,T越小,对的标签越大(例如一般是99%)别的都1%这样,那模型就学不到错误的标签了,所有蒸馏通过把T变大,让学生模型学到所有信息不仅是对的还有错的吗,上图就是这样,它通过soft的标签,通过T,得到一个蒸馏的损失,在让T=1,得到一个正常的损失和预测,从而完成蒸馏模型loss的联动。

distilbert就差不多是这样的,一个基本的蒸馏 用了6层tranformer,

 it is possible to reduce the size of a BERT model by 40%, while retaining 97% of its language understanding capabilities and being 60% faster. 

  • DistilBERT doesn’t have token_type_ids, you don’t need to indicate which token belongs to which segment. Just separate your segments with the separation token tokenizer.sep_token (or [SEP]).
  • DistilBERT doesn’t have options to select the input positions (position_ids input). This could be added if necessary though, just let us know if you need this option.

模型压缩除了蒸馏,还要量化和剪枝,共三个常见操作:

量化简单来说是把数据的bit变低,例如本来32bit存储,现在改成16这种,类似于图片的降低分辨率

剪枝就是把一些没用的参数去掉,例如设一个阈值,低于0.1的参数都设为0这种(当然有更复杂的,注意力剪枝等等)。

基础思想就上面,遇到了细讲。

你可能感兴趣的:(NLP算法遨游之路,bert,自然语言处理,深度学习)