英文题目:Meta-KD: A Meta Knowledge Distillation Framework for Language Model Compression across Domains
中文题目:Meta-KD:跨领域语言模型压缩的元知识蒸馏框架
论文地址:http://export.arxiv.org/pdf/2012.01266v1.pdf
领域:自然语言处理, 知识蒸馏
发表时间:2020.12
作者:Haojie Pan,阿里团队
出处:ACL
被引量:1
代码和数据:https://github.com/alibaba/EasyNLP(集成于EasyNLP)
阅读时间:2022-09-17
结合元学习和蒸馏学习:元学习使得模型获取调整超参数的能力,使其可以在已有知识的基础上快速学习新任务。
预训练的自然语言模型虽然效果好,但占空间大,预测时间长,使模型不能应用于实时预测任务。典型的方法是使用基于老师/学生模型的知识蒸馏。而模型一般面向单一领域,忽略了不同领域知识的知识转移。本文提出元蒸馏算法,致力于基于元学习的理论,让老师模型具有更大的转移能力,尤其对few-shot和zero-shot任务效果更好。
如图-1所示,一个学物理的学生如果跟数学老师学习了数学方程知识,可能有助于他更好地理解物理方程。相近领域的数据可能提升模型的能力,但其它领域模型也可能转移一些无关的知识,从而影响性能。另外,当前研究证明:使用多任务精调也未必能提升所有任务的性能。由此,文中提出需要让老师模型消化不同领域的知识,并可针对具体领域,将知识转移到学生模型。在图-1©中,如果有万能的科学老师(元学习),它既会数学也会物理,则可以更好地教导学生。
如图-2所示,模型包含两部分:元老师和元蒸馏:
首先利用多领域数据集训练元老师,通过引入破坏域损失来获取跨域知识,然后针对具体领域,用领域相关数据集引导元老师,以提升学生的蒸馏能力。
文章贡献
定义:设有K个领域的K个数据集参与训练,D为数据集,M为大模型,S为蒸馏后的学习模型。
模型训练分为两个场景:
将BERT模型作为基础模型。
基于原型实例加权
学习过程中对每个实例X计算原型得分t,假设处理分类问题,共m个类别,计算所有第K领域中实例属于每个类别的概率均值(请参考图-3左侧的实心多边形):
计算原型得分如下:
此处cos用于计算相似度,α是超参数,公式的前半部分计算了该实体与它所在的领域的关系(在嵌入空间与同类实体的一致性),后半部分计算了与其它领域的关系。这样模型就同时学习了同一领域的知识和其它领域的知识。
域破坏
除了交叉熵损失,还加入了域破坏损失以提升元老师转移学习的能力。对于每个实例,学习一个与h维度相同的域嵌入,记作ED(epsilon D)。
在BERT以外,又加入了一个子网络,对网络输出进一步处理:
针对域破坏的损失函数定义为:
其中σ(sigma)表示域类别,它是一个指示函数,只有0/1两个取值,这里最大化元教师对域标签做出错误预测的可能性。
我理解,这里的损失函数是让实例最终能识别它所在的域类别k。
损失函数
最终的损失定义为:使用得分t加权针对所有领域的交叉损失;同时,加入了域破坏损失作为辅助,以训练模型转移知识的能力。
这里的γ1(gamma)是超参数,用于设定域破坏损失的贡献。
使用小型的BERT作为学生模型,蒸馏网络结构如图-3所示:
目标由五个部分组成:输入嵌入Lembd,隐藏层状态Lhidn,注意力矩阵Lattn,输出ligit和知识转移。其中Lembd,Lhidn,Lattn的蒸馏方法与TinyBERT一样。又加入了Lpred对输出层使用软交叉熵损失。
另外,考虑到特定领域的知识转移,下面公式又加入了域相关的损失:
以此鼓励学生模型学习更多的该领域相关知识。我理解这里的hM是指对该领域的老师模型获得的编码。
又引入λk参数,它是领域相关的权重:
其中y^是预测的类别标签,当预测准确,或者t比较大时,λ值也相应变大,它反应的是老师在特定任务上监督学生的能力。
整体蒸馏损失计算方法如下:
使用自然语言推理(MNLI)和情绪分析(Amazon Reviews)两个任务评价模型。
表-2和3展示了主实验结果:
得出三个结论:
图-4也说明在few-shot情况下,实例越少,Meta-KD效果越明显: