【DL】模型蒸馏Distillation

【DL】模型蒸馏Distillation_第1张图片
过去一直follow着transformer系列模型的进展,从BERT到GPT2再到XLNet。然而随着模型体积增大,线上性能也越来越差,所以决定开一条新线,开始follow模型压缩之模型蒸馏的故事线。

Hinton在NIPS2014[1]提出了知识蒸馏(Knowledge Distillation)的概念,旨在把一个大模型或者多个模型ensemble学到的知识迁移到另一个轻量级单模型上,方便部署。简单的说就是用新的小模型去学习大模型的预测结果,改变一下目标函数。听起来是不难,但在实践中小模型真的能拟合那么好吗?所以还是要多看看别人家的实验,掌握一些trick。

0. 名词解释

teacher - 原始模型或模型ensemble
student - 新模型
transfer set - 用来迁移teacher知识、训练student的数据集合
soft target - teacher输出的预测结果(一般是softmax之后的概率)
hard target - 样本原本的标签
temperature - 蒸馏目标函数中的超参数
born-again network - 蒸馏的一种,指student和teacher的结构和尺寸完全一样
teacher annealing - 防止student的表现被teacher限制,在蒸馏时逐渐减少soft targets的权重

1. 基本思想

1.1 为什么蒸馏可以work
好模型的目标不是拟合训练数据,而是学习如何泛化到新的数据。所以蒸馏的目标是让student学习到teacher的泛化能力,理论上得到的结果会比单纯拟合训练数据的student要好。另外,对于分类任务,如果soft targets的熵比hard targets高,那显然student会学习到更多的信息。

1.2 蒸馏时的softmax
[公式]
在这里插入图片描述
比之前的softmax多了一个参数T(temperature),T越大产生的概率分布越平滑。

有两种蒸馏的目标函数:

1.只使用soft targets:在蒸馏时teacher使用新的softmax产生soft targets;student使用新的softmax在transfer set上学习,和teacher使用相同的T。
2.同时使用sotf和hard targets:student的目标函数是hard target和soft target目标函数的加权平均,使用hard target时T=1,soft target时T和teacher的一样。Hinton的经验是给hard target的权重小一点。另外要注意的是,因为在求梯度(导数)时新的目标函数会导致梯度是以前的在这里插入图片描述,所以要再乘上 在这里插入图片描述,不然T变了的话hard target不减小(T=1),但soft target会变。
直接用logits的MSE(是1的special case)

2. 蒸馏经验

2.1 Transfer Set和Soft target
实验证实,Soft target可以起到正则化的作用(不用soft target的时候需要early stopping,用soft target后稳定收敛)
数据过少的话无法完整表达teacher学到的知识,需要增加无监督数据(用teacher的预测作为标签)或进行数据增强,可以使用的方法有:1.增加[MASK],2.用相同POS标签的词替换,2.随机n-gram采样,具体步骤参考文献2
2.2 超参数T
T越大越能学到teacher模型的泛化信息。比如MNIST在对2的手写图片分类时,可能给2分配0.9的置信度,3是1e-6,7是1e-9,从这个分布可以看出2和3有一定的相似度,因此这种时候可以调大T,让概率分布更平滑,展示teacher更多的泛化能力
T可以尝试1~20之间
2.3 BERT蒸馏
蒸馏单BERT[2]:模型架构:单层BiLSTM;目标函数:logits的MSE
蒸馏Ensemble BERT[3]:模型架构:BERT;目标函数:soft prob+hard prob;方法:MT-DNN。该论文用给每个任务训练多个MT-DNN,取soft target的平均,最后再训一个MT-DNN,效果比纯BERT好3.2%。但感觉该研究应该是刷榜的结晶,平常应该没人去训BERT ensemble吧。。
BAM[4]:Born-aging Multi-task。用多个任务的Single BERT,蒸馏MT BERT;目标函数:多任务loss的和;方法:在mini-batch中打乱多个任务的数据,任务采样概率为 [公式] ,防止某个任务数据过多dominate模型、teacher annealing、layerwise-learning-rate,LR由输出层到输出层递减,因为前面的层需要学习到general features。最终student在大部分任务上超过teacher,而且上面提到的tricks也提供了不少帮助。文献4还不错,推荐阅读一下。
TinyBERT[5]:截止201910的SOTA。利用Two-stage方法,分别对预训练阶段和精调阶段的BERT进行蒸馏,并且不同层都设计了损失函数。与其他模型的对比如下:
【DL】模型蒸馏Distillation_第2张图片

3. 总结

再重点强调一下,student学习的是teacher的泛化能力,而不是“过拟合训练数据”。

目前读的论文不是很多,但个人感觉还是有些炼丹的感觉。文献2中的BERT蒸馏任务,虽然比无蒸馏条件下有将近5个点的提升,但作者没有研究到底是因为数据增多还是蒸馏带来的提升。而且仍然距BERT有很大的距离,虽然速度提升了,但效果并不能上线。文献3中虽然有了比BERT更好的效果,但并没有用轻量的结构,性能还是不变。

接下来我会花时间读更多的论文,写新文章或把tricks加进这篇文章里,同学们有好的经验也可以说一下。

补充一些资源,还没仔细看:

dkozlov/awesome-knowledge-distillation
Distilling BERT Models with spaCy
DistilBERT
Multilingual MiniBERT: Tsai et al. (EMNLP 2019)

你可能感兴趣的:(算法)