连续学习之 -- Overcoming catastrophic forgetting in neural networks

Overcoming catastrophic forgetting in neural networks

      • 0. Continual learning basics
      • 1. 主要内容
        • 1.0 EWC (elastic weight consolidation)
      • References

本篇博客,主要分析 “Overcoming catastrophic forgetting in neural networks” 一文的主要思想, 来了解连续学习中的参数正则这一类方法。这是来自2016的工作,它将continual learning 的方法用到神经网络中,用于减轻神经网络经常碰到的灾难性遗忘现象,特别是在处理序列类型任务时:模型倾向于学习当前任务,而用最新的模型去预测之前的任务,性能大大降低。

0. Continual learning basics

连续学习旨在模仿人类学习的过程,即人在不断学习新知识的同时,不遗忘之前已经学到的旧知识。虽然,随着数据的爆炸性增长以及计算力的大幅增长,深度神经网络在很多领域取得了巨大的成功,例如计算机视觉,NLP等等。当然,目前的深度学习技术也面临很多问题,如模型泛化性,模型可解释性,灾难性遗忘等问题。这篇博客分析的内容主要涉及灾难性遗忘这一问题:

  • 灾难性遗忘:深度神经网络,往往在处理单个任务时,性能良好。但是在处理序列类型任务时,随着任务的增多,模型会倾向于当前任务的学习,而在之前的任务上性能大大下降。

大致上,连续学习方法可以分为如下三类 [2]:

  • Replay-based 方法:这类方法在学习新任务的同时,通过使用之前任务的一些典型样本来克服新训练的模型对之前任务的遗忘。当然这会增加模型的额外存储。

  • Regularization-based 方法:这类方法又可以分为两子类,data-focused 方法(例如知识蒸馏)和 priori-focused 方法(利用模型参数的分布先验,例如估计参数的重要性,如果某些参数对之前的任务重要,那么在学习新任务的过程中应尽量少改变这类参数,因此这类方法也称为参数正则方法)。当然,这个方法的关键就是如何估计参数的重要性。

  • Parameter isolation 方法:这类方法的思想更加直接,即对不同的任务使用不同的网络结构,具体而言,针对不同的任务可以单独设计模型,或者仅仅使用一个模型的部分参数。

本篇博客分析的文章 – “Overcoming catastrophic forgetting in neural networks”, 属于 Regularization-based 这一类中参数正则的方法,它通过计算Fisher信息矩阵,估计参数的重要性,从而在学习新任务的过程中,对之前任务重要的参数尽量少改变。

1. 主要内容

这篇文章[1]的主要贡献:提出了文中所称的 “elastic weight consolidation ” (EWC)算法,利用Fisher信息矩阵,确定对之前任务重要的参数;然后在学习新任务时,通过一参数正则loss的方式,减少对这些重要参数的改变。不失一般性,以连续学习两个任务 T a s k   A , T a s k   B Task ~A,Task ~B Task A,Task B 为例,可类似推广到多个任务的情形。为了阐述的方便,首先对下文中的符号做一些简单的说明:

  • θ \theta θ:表示神经网络 f ( ⋅ ) f(\cdot) f() 的参数,包括所有权重 w w w 和偏置 b b b
  • θ A ∗ \theta_{A}^{*} θA:表示神经网络 f ( ⋅ ) f(\cdot) f() 训练 Task A 结束得到的最优模型参数。
  • θ B ∗ \theta_{B}^{*} θB: 表示在训练 T a s k   A Task ~A Task A 结束之后,接着训练 T a s k   B Task ~ B Task B 得到的最优模型参数。

EWC 的具体内容如下:

1.0 EWC (elastic weight consolidation)

  • Motivation
    E W C EWC EWC 的目的就是,在 T a s k   A Task ~A Task A 训练结束之后,计算 θ A ∗ \theta_{A}^{*} θA 中的参数关于 T a s k   A Task ~A Task A 的重要性。这个重要性指标,根据Fisher信息矩阵得到,然后在训练 T a s k   B Task ~B Task B 时,对这些于 T a s k   A Task ~A Task A 重要的参数,加强限制(也即正则),如下式所示: L ( θ ) = L B ( θ ) + ∑ i λ 2 F i ( θ i − θ A , i ∗ ) 2 (1) \mathcal{L}(\theta)=\mathcal{L}_{B}(\theta) + \sum_{i}\frac{\lambda}{2}F_{i}(\theta_{i} - \theta_{A,i}^*) ^2\tag{1} L(θ)=LB(θ)+i2λFi(θiθA,i)2(1)
    其中 L B ( θ ) \mathcal{L}_{B}(\theta) LB(θ) 表示 T a s k   B Task~B Task B 的 loss, F i ≥ 0 F_i\geq 0 Fi0 是参数 θ A , i ∗ \theta_{A,i}^{*} θA,i 的重要性系数(根据Fisher信息矩阵计算得到,具体见后面的分析), λ \lambda λ 则表示平衡两个loss的超参数。
    根据公式(1),易见:如果某个参数 θ A , i ∗ \theta_{A,i}^* θA,i T a s k   A Task~A Task A 重要,也即 F i F_i Fi 越大, 则在优化 T a s k   B Task~B Task B 时,会限制 θ i \theta_{i} θi θ A , i ∗ \theta_{A,i}^* θA,i 附近,从而减缓新优化的模型对 T a s k   A Task~A Task A 的遗忘,如 Fig.1 所示。
    连续学习之 -- Overcoming catastrophic forgetting in neural networks_第1张图片
  • 计算参数重要性
    原文从概率的角度,来推导重要参数的约束形式。神经网络的训练过程,可视作通过训练数据集 D \mathcal{D} D,优化网络参数 θ \theta θ 来最大化条件概率 p ( θ ∣ D ) p(\theta|\mathcal{D}) p(θD),而根据 B a y e s Bayes Bayes 准则,有: log ⁡ p ( θ ∣ D ) = log ⁡ p ( D ∣ θ ) + log ⁡ p ( θ ) − log ⁡ p ( D ) (2) \log{p(\theta|\mathcal{D})}=\log{p(\mathcal{D}|\theta)} + \log{p(\theta)} - \log{p(\mathcal{D})} \tag{2} logp(θD)=logp(Dθ)+logp(θ)logp(D)(2) 注意到,等式右边第一项 log ⁡ p ( D ∣ θ ) × ( − 1 ) \log{p(\mathcal{D}|\theta)}\times (-1) logp(Dθ)×(1),即为通常的loss函数 L ( θ ) \mathcal{L(\theta)} L(θ)
    如果进一步把数据 D \mathcal{D} D 分为两个相互独立的数据子集 D A , D B \mathcal{D}_{A},\mathcal{D}_B DA,DB, 分别用于 T a s k   A , T a s k   B Task~A, Task~B Task A,Task B。那么由公式(2)可得: log ⁡ p ( θ ∣ D ) = log ⁡ p ( D B ∣ θ ) + log ⁡ p ( θ ∣ D A ) − log ⁡ p ( D B ) (3) \log{p(\theta|\mathcal{D})}=\log{p(\mathcal{D}_{B}|\theta)} + \log{p(\theta|\mathcal{D}_{A})} - \log{p(\mathcal{D}_{B})} \tag{3} logp(θD)=logp(DBθ)+logp(θDA)logp(DB)(3) 因此关于 T a s k   A Task~A Task A 的信息,都集中在等式右边的第二项 log ⁡ p ( θ ∣ D A ) \log{p(\theta|\mathcal{D}_{A})} logp(θDA) 这个后验分布上。
    一般而言, log ⁡ p ( θ ∣ D A ) \log{p(\theta|\mathcal{D}_{A})} logp(θDA) 没有解析解或者非常复杂,很难直接处理。如果,对它进行一些简单近似,则可以得到某些简单形式。如原文中采用 Laplace 近似(参考[3]),假设 p ( θ ∣ D A ) p(\theta|\mathcal{D}_{A}) p(θDA) 服从高斯分布,均值为 θ A ∗ \theta_A^* θA,协方差矩 ∑ \sum 满足:
    ∑ − 1 = [ F 1 0 ⋯ 0 0 F 2 ⋯ 0 ⋮ ⋮ ⋱ ⋮ 0 0 ⋯ F n ] (4) \sum {^{-1}}=\begin{bmatrix} F_1 & 0 & \cdots & 0 \\ 0 & F_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & F_n\\ \end{bmatrix}\tag{4} 1=F1000F2000Fn(4)
    其中 n n n 表示参数的数目,而 { F i : i = 1 , ⋯   , n } \{F_i:i=1,\cdots,n\} {Fi:i=1,,n} 来自参数 θ A ∗ \theta_A^* θA 的 Fisher 信息矩阵的对角线元素,具体而言,就是 F i = ∣ ∂ L ( θ ) ∂ θ ∣ θ A ∗ ∣ 2 F_i=|\frac{\partial\mathcal{L}(\theta)}{\partial\theta}|_{\theta_A^*}|^2 Fi=θL(θ)θA2。将上面的假设带入(3),经过简单的计算,可得最终的EWC的正则形式,也即上面提到的公式(1):
    L ( θ ) = L B ( θ ) + ∑ i λ 2 F i ( θ i − θ A , i ∗ ) 2 \mathcal{L}(\theta)=\mathcal{L}_{B}(\theta) + \sum_{i}\frac{\lambda}{2}F_{i}(\theta_{i} - \theta_{A,i}^*) ^2 L(θ)=LB(θ)+i2λFi(θiθA,i)2

以上就是EWC中关于参数重要性计算的内容。我们主要分析参数正则方法在连续学习中的应用的原理或者思想,其他的试验内容参见原文[1]的具体试验分析。

Remarks: 1)如果在公式(1)中,直接使用 L 2 L_{2} L2 正则 Task A的参数 θ A ∗ \theta_{A}^* θA,则相当于各个参数 θ A , i ∗ \theta_{A,i}^* θA,i 的重要性是同等的。因此 L 2 L_2 L2 正则可以看做是 EWC 的一个特殊情形。2)另外,EWC 需要在每个任务训练结束之后,更新各个参数的重要性。


结束语:虽然 EWC 在一定程度上能够减轻模型的遗忘性问题,但是重要参数还是会慢慢偏离之前任务学习到的最优参数,从长远来看还是会出现遗忘的情形,特别是在序列处理的任务数量不断增加时,这也是所谓的model drift 问题。因此关于连续学习,如何有效地用在神经网络中,还有很多的问题需要解决。


References

[1] Overcoming catastrophic forgetting in neural networks.
[2] A continual learning survey: Defying forgetting in classification tasks.
[3] A practical Bayesian framework for backpropagation networks.

你可能感兴趣的:(深度学习,深度学习,神经网络,机器学习,概率论)