【Lifelong learning】Lifelong Language Knowledge Distillation

链接:http://arxiv.org/abs/2010.02123

简介

Lifelong Language Knowledge Distillation终身语言知识提炼,是一种利用知识蒸馏的终身学习方法。
其主要思想是:每次遇到新任务时,不直接让model去学习,而是先在任务上训练一个teacher model,然后运用知识蒸馏技术,将知识传递给model。

  • 知识蒸馏:有两个模型: student model(小)和teacher model(大)。student model需要通过训练,模仿teacher model的行为并使得两者性能相近。

本文将知识蒸馏的思想运用到了终身学习的语言领域。但不同之处在于: L2KD的student model和teacher model是一样大的。

如下图所示。这种方法只需要为每个新任务多花一点时间训练一个一次性teacher model,在学习下一个任务时可以丢弃该模型;因此,L2KD不需要额外的内存或模型容量,这使得提出的模型在实际使用中更有效。【Lifelong learning】Lifelong Language Knowledge Distillation_第1张图片

必须要指出的是:L2KD作为一种方法而非具体模型,可以加到大部分LLL模型上去。
因此,本文就将L2KD加到了LAMOL上去。
LAMOL介绍:https://blog.csdn.net/Baigker/article/details/121650749?spm=1001.2014.3001.5501

Proposed Approach

正如在简介中提到的,L2DK本质是一种知识蒸馏,并且在实际运用中要加到其他模型上去。因此本文也遵循这一顺序,即:先介绍LAMOL,再介绍知识蒸馏,最后才说明L2KD的原理。

LAMOL

在LAMOL的setting中,语言数据集中的所有样本都有三个部分:上下文、问题和答案。我们可以简单地将这三个部分连接成一个句子,并训练模型根据上下文和前面的问题生成答案,如图所示。【Lifelong learning】Lifelong Language Knowledge Distillation_第2张图片
除了生成给定问题的答案外,该模型同时学习建模整个训练样本。通过这样做,在训练下一个任务时,模型可以为前一个任务生成训练样本,同时训练新任务的数据和为前一个任务生成的伪数据。因此,模型在适应新任务时忘记的更少。

知识蒸馏

语言模型

一般来说,语言模型的目标是使预测下一个词时的负对数似然(NLL)最小化:在这里插入图片描述
而在知识蒸馏中,我们将student model和teacher model之间的预测误差最小化。考虑错误的目标单元可以在单词级或序列级进行。

Word-Level (Word-KD)

在预测下一个词时,我们最小化student和teacher的输出分布之间的交叉熵:
【Lifelong learning】Lifelong Language Knowledge Distillation_第3张图片

Sequence-Level (Seq-KD)

我们将teacher model中的贪心解码或beam search输出序列 x ^ \hat x x^作为硬目标直接最小化负对数似然,就像普通语言建模一样:在这里插入图片描述

Soft Sequence-Level (Seq-KDsoft)

我们进一步研究软目标加上teacher解码序列是否对模型更有帮助,因此我们进行 S e q − K D s o f t Seq-KD_{soft} SeqKDsoft,对teacher model的贪心解码或beam search输出进行Word-KD。 S e q − K D s o f t Seq-KD_{soft} SeqKDsoft和Word-KD之间的唯一区别是Word-KD的输入 x < t xx<t现在被替换为 x ^ < t \hat xx^<t,teacher model的输出序列:【Lifelong learning】Lifelong Language Knowledge Distillation_第4张图片
注意,无论我们在知识蒸馏中使用何种损失函数,teacher model总是固定的。因此,LLL模型求参数 θ S ∗ θ^*_S θS的优化过程可以写成:在这里插入图片描述
知识蒸馏可以应用于最小化LM和QA在LAMOL中的损失。假设有一个任务流的数据集 { D 1 , D 2 , … } \{ D_1, D_2,…\} { D1,D2},我们的LLL模型从 D 1 D_1 D1学习到 D m − 1 D_{m-1} Dm1,现在适用于 D m D_m Dm。首先,我们通过最小化LAMOL中LM和QA的负对数似然损失来训练 D m D_m Dm的教师模型,并获得模型参数 θ m T θ_m^T θmT
现在我们的LLL模型(参数 θ S θ_S θS)可以通过从教师模型中提取知识来训练 D m D_m Dm。给定一个训练样本 X i m = { x 1 , x 2 , … , x T } ∈ D m X^m_i = \{ x_1, x_2,…, x_T\} ∈D_m Xim={ x1,x2xT}Dm(包括上下文、问题和答案),我们将其最小化:【Lifelong learning】Lifelong Language Knowledge Distillation_第5张图片
其中 a 1 a_1 a1表示答案的起始位置。这里我们以Word-KD为例,但我们也可以将答案部分的文本替换为教师生成的答案,从而进行 S e q − K D s o f t Seq-KD_{soft} SeqKDsoft或Seq-KD。
LLL模型除了对来自 D m D_m Dm的样本进行训练外,还会为之前的任务生成伪数据 D p r e v D_{prev} Dprev。然而,对于 D p r e v D_{prev} Dprev中的样本,我们不能在这里进行知识蒸馏,因为在我们的设置中,之前任务的教师模型在适应下一个任务后将被丢弃。因此,给定生成的数据 X i p r e v ∈ D p r e v X^{prev}_i∈D_{prev} XiprevDprev,我们在这里只最小化NLL损失:【Lifelong learning】Lifelong Language Knowledge Distillation_第6张图片
最后,我们共同优化了两种损失,得到了LLL模型的参数 θ S ∗ θ^∗_S θS:在这里插入图片描述

整体算法流程:
【Lifelong learning】Lifelong Language Knowledge Distillation_第7张图片

你可能感兴趣的:(论文阅读,深度学习,机器学习)