持续学习:(Elastic Weight Consolidation, EWC)Overcoming Catastrophic Forgetting in Neural Network

EWC目录

  • 概述
  • 1. 基础知识
    • 1.1 基本概念
    • 1.2 贝叶斯法则
  • 2. Elastic Weight Consolidation
    • 2.1 参数定义
    • 2.2 EWC 方法推导
  • 3. 拉普拉斯近似
    • 3.1 高斯分布拟合
    • 3.2 Fisher Information Matrix
      • 3.2.1 Fisher Information Matrix 的含义
      • 3.2.2 Fisher 信息矩阵与 Hessian 矩阵

概述

原论文地址:https://arxiv.org/pdf/1612.00796.pdf

本博客参考了以下博客的理解
地址:https://blog.csdn.net/dhaiuda/article/details/103967676/

本博客仅是个人对此论文的理解,若有理解不当的地方欢迎大家指正。

本篇论文讲述了一种通过给权重添加正则,从而控制权重优化方向,从而达到持续学习效果的方法。其方法简单来讲分为以下三个步骤,其思想如图所示:

  • 选择出对于旧任务(old task)比较重要的权重
  • 对权重的重要程度进行排序
  • 在优化的时候,越重要的权重改变越小,保证其在小范围内改变,不会对旧任务产生较大的影响
    持续学习:(Elastic Weight Consolidation, EWC)Overcoming Catastrophic Forgetting in Neural Network_第1张图片
    在图中,灰色区域时旧任务的低误差区域,白色为新任务的低误差区域。如果用旧任务的权重初始化网络,用新任务的数据进行训练的话,优化的方向如蓝色箭头所示,离开了灰色区域,代表着其网络失去了在旧任务上的性能。通过控制优化方向,使得其能够处于两个区域的交集部分,便代表其在旧任务与新任务上都有良好的性能。

具体方法为:将模型的后验概率拟合为一个高斯分布,其中均值为旧任务的权重,方差为 Fisher 信息矩阵(Fisher Information Matrix)的对角元素的倒数。方差就代表了每个权重的重要程度。

1. 基础知识

1.1 基本概念

  • 灾难性遗忘(Catastrophic Forgetting):在网络顺序训练多个任务的时候,对于先前任务的重要权重无法保留。灾难性遗忘是网络结构的必然特征
  • 持续学习:在顺序学习任务的时候,不忘记之前训练过的任务。根据任务A训练网络之后,再根据任务B训练同一个网络,此时对任务A进行测试,还可以维持其性能。

1.2 贝叶斯法则

P ( A ∣ B ) = P ( A ∩ B ) P ( B ) P(A|B) = \frac{P(A \cap B)}{P(B)} P(AB)=P(B)P(AB)
P ( B ∣ A ) = P ( A ∩ B ) P ( A ) P(B|A) = \frac{P(A \cap B)}{P(A)} P(BA)=P(A)P(AB)

P ( A ∣ B ) P ( B ) = P ( B ∣ A ) P ( A ) P(A|B)P(B)=P(B|A)P(A) P(AB)P(B)=P(BA)P(A)
所以可以得到
P ( B ∣ A ) = P ( A ∣ B ) P ( B ) P ( A ) P(B|A) = P(A|B)\frac{P( B)}{P(A)} P(BA)=P(AB)P(A)P(B)

2. Elastic Weight Consolidation

2.1 参数定义

  • θ \theta θ:网络的参数
  • θ A ∗ \theta^*_A θA:对于任务A,网络训练得到的最优参数
  • D D D:全体数据集
  • D A D_A DA:任务 A 的数据集
  • D B D_B DB:任务 B 的数据集
  • F F F:Fisher 信息矩阵
  • H H H:Hessian 矩阵

2.2 EWC 方法推导

对于网络来讲,给定数据集,目的是寻找一个最优的参数,即
P ( θ ∣ D ) P(\theta|D) P(θD)
根据贝叶斯准则
P ( B ∣ A ) = P ( A ∣ B ) P ( B ) P ( A ) P(B|A) = P(A|B)\frac{P( B)}{P(A)} P(BA)=P(AB)P(A)P(B)
可以得到最大后验概率:
P ( θ ∣ D ) = P ( D ∣ θ ) P ( θ ) P ( D ) P(\theta|D) = P(D|\theta)\frac{P( \theta)}{P(D)} P(θD)=P(Dθ)P(D)P(θ)
于是可以得到
log ⁡ P ( θ ∣ D ) = log ⁡ ( P ( D ∣ θ ) P ( θ ) P ( D ) ) = log ⁡ P ( D ∣ θ ) + log ⁡ P ( θ ) − log ⁡ P ( D ) \log P(\theta|D) = \log (P(D|\theta)\frac{P( \theta)}{P(D)})=\log P(D|\theta) + \log P( \theta) - \log P(D) logP(θD)=log(P(Dθ)P(D)P(θ))=logP(Dθ)+logP(θ)logP(D)
也就是论文中的公式(1)

如果这是两个任务的顺序学习,旧任务为任务 A,新任务为任务 B,那么可以数据集 D D D 可以划分为 D A D_A DA D B D_B DB,则
P ( θ ∣ D A , D B ) = P ( θ , D A , D B ) P ( D A , D B ) = P ( θ , D B ∣ D A ) P ( D A ) P ( D B ∣ D A ) P ( D A ) = P ( θ , D B ∣ D A ) P ( D B ∣ D A ) P(\theta|D_A,D_B)=\frac{P(\theta,D_A,D_B)}{P(D_A,D_B)}=\frac{P(\theta,D_B|D_A)P(D_A)}{P(D_B|D_A)P(D_A)}=\frac{P(\theta,D_B|D_A)}{P(D_B|D_A)} P(θDA,DB)=P(DA,DB)P(θ,DA,DB)=P(DBDA)P(DA)P(θ,DBDA)P(DA)=P(DBDA)P(θ,DBDA)
又因为
P ( θ , D B ∣ D A ) = P ( θ , D B , D A ) = P ( θ , D A , D B ) P ( D A ) = P ( θ , D A , D B ) P ( θ , D A ) ⋅ P ( θ , D A ) P ( D A ) = P ( D B ∣ θ , D A ) P ( θ ∣ D A ) P(\theta,D_B|D_A)=P(\theta,D_B,D_A)=\frac{P(\theta,D_A,D_B)}{P(D_A)}=\frac{P(\theta,D_A,D_B)}{P(\theta,D_A)} \cdot \frac{P(\theta,D_A)}{P(D_A)}=P(D_B|\theta,D_A)P(\theta|D_A) P(θ,DBDA)=P(θ,DB,DA)=P(DA)P(θ,DA,DB)=P(θ,DA)P(θ,DA,DB)P(DA)P(θ,DA)=P(DBθ,DA)P(θDA)
所以,可以得到
P ( θ ∣ D A , D B ) = P ( θ , D B ∣ D A ) P ( D B ∣ D A ) = P ( D B ∣ θ , D A ) P ( θ ∣ D A ) P ( D B ∣ D A ) P(\theta|D_A,D_B)=\frac{P(\theta,D_B|D_A)}{P(D_B|D_A)}=\frac{P(D_B|\theta,D_A)P(\theta|D_A)}{P(D_B|D_A)} P(θDA,DB)=P(DBDA)P(θ,DBDA)=P(DBDA)P(DBθ,DA)P(θDA)
又因为 D A D_A DA D B D_B DB 独立,所以可以得到
P ( D B ∣ D A ) = P ( D B ) P(D_B|D_A)=P(D_B) P(DBDA)=P(DB)
P ( D B ∣ θ , D A ) = P ( D B ∣ θ ) P(D_B|\theta,D_A)=P(D_B|\theta) P(DBθ,DA)=P(DBθ)
所以
P ( θ ∣ D A , D B ) = P ( D B ∣ θ ) P ( θ ∣ D A ) P ( D B ) P(\theta|D_A,D_B)=\frac{P(D_B|\theta)P(\theta|D_A)}{P(D_B)} P(θDA,DB)=P(DB)P(DBθ)P(θDA)
同样对于两边取 log,可以得到
log ⁡ P ( θ ∣ D ) = log ⁡ P ( θ ∣ D A , D B ) = log ⁡ P ( D B ∣ θ ) + l o g P ( θ ∣ D A ) − log ⁡ P ( D B ) \log P(\theta|D)=\log P(\theta|D_A,D_B)= \log P(D_B|\theta)+logP(\theta|D_A)-\log P(D_B) logP(θD)=logP(θDA,DB)=logP(DBθ)+logP(θDA)logP(DB)
这个便是论文中的公式(2),也是这篇论文的核心内容。

在给定整个数据集,我们需要得到一个 θ \theta θ 使得概率最大,那么也就是分别优化上式的右边三项。

第一项很明显可以理解为任务B的损失函数,将其命名为 L B ( θ ) L_B(\theta) LB(θ),第三项对于 θ \theta θ 来讲是一个常数,那么网络的优化目标便是
m a x θ log ⁡ P ( θ ∣ D ) = m a x θ ( − L B ( θ ) + log ⁡ P ( θ ∣ D A ) ) \mathop{max}\limits_{\theta}\log P(\theta|D)=\mathop{max}\limits_{\theta}(-L_B(\theta)+\log P(\theta|D_A)) θmaxlogP(θD)=θmax(LB(θ)+logP(θDA))

m i n θ ( L B ( θ ) − log ⁡ P ( θ ∣ D A ) ) \mathop{min}\limits_{\theta}(L_B(\theta)-\log P(\theta|D_A)) θmin(LB(θ)logP(θDA))
现在,重点变成了如何优化后验概率 log ⁡ P ( θ ∣ D A ) \log P(\theta|D_A) logP(θDA) ,作者采用了拉普拉斯近似的方法进行量化。

3. 拉普拉斯近似

由于后验概率并不容易进行衡量,所以我们将其先验 log ⁡ P ( D A ∣ θ ) \log P(D_A|\theta) logP(DAθ) 拟合为一个高斯分布

3.1 高斯分布拟合

令先验 log ⁡ P ( D A ∣ θ ) \log P(D_A|\theta) logP(DAθ) 服从高斯分布
P ( D A ∣ θ ) ∼ N ( μ , σ ) P(D_A|\theta) \sim N(\mu, \sigma) P(DAθ)N(μ,σ)
那么由高斯分布的公式可以得到
P ( D A ∣ θ ) = 1 2 π σ e − ( θ − μ ) 2 2 σ 2 P(D_A|\theta)=\frac{1}{\sqrt{2 \pi}\sigma} e^{-\frac{(\theta-\mu)^2}{2\sigma^2}} P(DAθ)=2π σ1e2σ2(θμ)2
那么,可以得到
log ⁡ P ( D A ∣ θ ) = log ⁡ 1 2 π σ − ( θ − μ ) 2 2 σ 2 \log P(D_A|\theta)=\log \frac{1}{\sqrt{2 \pi}\sigma} -\frac{(\theta-\mu)^2}{2\sigma^2} logP(DAθ)=log2π σ12σ2(θμ)2

f ( θ ) = log ⁡ P ( D A ∣ θ ) f(\theta)=\log P(D_A|\theta) f(θ)=logP(DAθ)
θ = θ A ∗ \theta = \theta_A^* θ=θA 处进行泰勒展开,可以得到
f ′ ( θ A ∗ ) = 0 f'(\theta_A^*)=0 f(θA)=0
f ( θ ) = f ( θ A ∗ ) + f ′ ( θ A ∗ ) ( θ − θ A ∗ ) + f ′ ′ ( θ A ∗ ) ( θ − θ A ∗ ) 2 2 + o ( θ A ∗ ) f(\theta)=f(\theta_A^*)+f'(\theta_A^*)(\theta-\theta_A^*)+f''(\theta_A^*)\frac{(\theta-\theta_A^*)^2}{2}+o(\theta_A^*) f(θ)=f(θA)+f(θA)(θθA)+f(θA)2(θθA)2+o(θA)
所以
log ⁡ 1 2 π σ − ( θ − μ ) 2 2 σ 2 ≈ f ( θ A ∗ ) + f ′ ′ ( θ A ∗ ) ( θ − θ A ∗ ) 2 2 \log \frac{1}{\sqrt{2 \pi}\sigma} -\frac{(\theta-\mu)^2}{2\sigma^2}\approx f(\theta_A^*)+f''(\theta_A^*)\frac{(\theta-\theta_A^*)^2}{2} log2π σ12σ2(θμ)2f(θA)+f(θA)2(θθA)2
其中, log ⁡ 1 2 π σ \log \frac{1}{\sqrt{2 \pi}\sigma} log2π σ1 f ( θ A ∗ ) f(\theta_A^*) f(θA) 都是常数,可以得到
− ( θ − μ ) 2 2 σ 2 = f ′ ′ ( θ A ∗ ) ( θ − θ A ∗ ) 2 2 -\frac{(\theta-\mu)^2}{2\sigma^2}= f''(\theta_A^*)\frac{(\theta-\theta_A^*)^2}{2} 2σ2(θμ)2=f(θA)2(θθA)2
因此,可以得到
μ = θ A ∗ \mu = \theta_A^* μ=θA
σ 2 = − 1 f ′ ′ ( θ A ∗ ) \sigma^2=-\frac{1}{f''(\theta_A^*)} σ2=f(θA)1
所以,可以得到
P ( D A ∣ θ ) ∼ N ( θ A ∗ , − 1 f ′ ′ ( θ A ∗ ) ) P(D_A|\theta) \sim N(\theta_A^*, -\frac{1}{f''(\theta_A^*)}) P(DAθ)N(θA,f(θA)1)
根据贝叶斯准则,
P ( θ ∣ D A ) = P ( P A , θ ) P ( θ ) P ( A ) P(\theta|D_A) = \frac{P(P_A,\theta)P(\theta)}{P(A)} P(θDA)=P(A)P(PA,θ)P(θ)
其中, P ( θ ) P(\theta) P(θ) 符合均匀分布, P ( D A ) P(D_A) P(DA) 为常数,所以
P ( θ ∣ D A ) ∼ N ( θ A ∗ , − 1 f ′ ′ ( θ A ∗ ) ) P(\theta|D_A) \sim N(\theta_A^*, -\frac{1}{f''(\theta_A^*)}) P(θDA)N(θA,f(θA)1)
此时,优化函数
m i n θ ( L B ( θ ) − log ⁡ P ( θ ∣ D A ) ) \mathop{min}\limits_{\theta}(L_B(\theta)-\log P(\theta|D_A)) θmin(LB(θ)logP(θDA))
可以变换为
m i n θ ( L B ( θ ) − f ′ ′ ( θ A ∗ ) ( θ − θ A ∗ ) 2 2 \mathop{min}\limits_{\theta}(L_B(\theta)- f''(\theta_A^*)\frac{(\theta-\theta_A^*)^2}{2} θmin(LB(θ)f(θA)2(θθA)2
对于一个batch来说,即为
m i n θ ( L B ( θ ) − ∑ i f i ′ ′ ( θ A ∗ ) ( θ i − θ A , i ∗ ) 2 2 \mathop{min}\limits_{\theta}(L_B(\theta)- \sum_i f''_i(\theta_A^*)\frac{(\theta_i-\theta_{A,i}^*)^2}{2} θmin(LB(θ)ifi(θA)2(θiθA,i)2
那么 f ′ ′ ( θ A ∗ ) f''(\theta_A^*) f(θA) 该如何求解呢?

3.2 Fisher Information Matrix

3.2.1 Fisher Information Matrix 的含义

Fisher information 是概率分布梯度的协方差。为了更好的说明Fisher Information matrix 的含义,这里定义一个得分函数 S S S
S ( θ ) = ∇ log ⁡ p ( x ∣ θ ) S(\theta)=\nabla \log p(x|\theta) S(θ)=logp(xθ)

E p ( x ∣ θ ) [ S ( θ ) ] = E p ( x ∣ θ ) [ ∇ log ⁡ p ( x ∣ θ ) ] = ∫ ∇ log ⁡ p ( x ∣ θ ) ⋅ p ( x ∣ θ ) d θ = ∫ ∇ p ( x ∣ θ ) p ( x ∣ θ ) ⋅ p ( x ∣ θ ) d θ = ∫ ∇ p ( x ∣ θ ) d θ = ∇ ∫ p ( x ∣ θ ) d θ = ∇ 1 = 0 \begin{aligned} \mathop{E}\limits_{p(x|\theta)}[S(\theta)] &=\mathop{E}\limits_{p(x|\theta)}[\nabla \log p(x|\theta)] \\ &= \int \nabla \log p(x|\theta) \cdot p(x|\theta) d\theta \\ &= \int \frac{\nabla p(x|\theta)}{p(x|\theta)} \cdot p(x|\theta) d\theta \\ &= \int \nabla p(x|\theta) d\theta \\ &= \nabla \int p(x|\theta) d\theta \\ & = \nabla 1=0 \end{aligned} p(xθ)E[S(θ)]=p(xθ)E[logp(xθ)]=logp(xθ)p(xθ)dθ=p(xθ)p(xθ)p(xθ)dθ=p(xθ)dθ=p(xθ)dθ=1=0
那么 Fisher Information matrix F F F
F = E p ( X ∣ θ ) [ ( S ( θ ) − 0 ) ( S ( θ ) − 0 ) T ] F = \mathop{E}\limits_{p(X|\theta)}[(S(\theta)-0)(S(\theta)-0)^T] F=p(Xθ)E[(S(θ)0)(S(θ)0)T]
对于每一个batch的数据 X = { x 1 , x 2 , ⋯   , x n } X = \{x_1,x_2,\cdots ,x_n\} X={x1,x2,,xn},则其定义为
F = 1 N ∑ i = 1 N ∇ log ⁡ p ( x i ∣ θ ) ∇ log ⁡ p ( x i ∣ θ ) T F = \frac{1}{N}\sum_{i=1}^N \nabla \log p(x_i|\theta) \nabla \log p(x_i|\theta)^T F=N1i=1Nlogp(xiθ)logp(xiθ)T

3.2.2 Fisher 信息矩阵与 Hessian 矩阵

Hessian矩阵为
H log ⁡ p ( x ∣ θ ) = J ( ∇ log ⁡ p ( x ∣ t h e t a ) ) = J ( ∇ p ( x ∣ θ ) p ( x ∣ θ ) ) = H p ( x ∣ θ ) p ( x ∣ θ ) − ∇ p ( x ∣ θ ) ∇ p ( x ∣ θ ) T p ( x ∣ θ ) p ( x ∣ θ ) = H p ( x ∣ θ ) p ( x ∣ θ ) − ∇ p ( x ∣ θ ) ∇ p ( x ∣ θ ) T p ( x ∣ θ ) p ( x ∣ θ ) = H p ( x ∣ θ ) p ( x ∣ θ ) − ( ∇ p ( x ∣ θ ) p ( x ∣ θ ) ) ( ∇ p ( x ∣ θ ) p ( x ∣ θ ) ) T \begin{aligned} H_{\log p(x|\theta)} &= J(\nabla \log p(x|theta)) = J(\frac{ \nabla p(x|\theta)}{ p(x|\theta)}) \\ &= \frac{H_{ p(x|\theta)} p(x|\theta)- \nabla p(x|\theta) \nabla p(x|\theta)^T}{ p(x|\theta) p(x|\theta)} \\ &= \frac{H_{ p(x|\theta)}}{ p(x|\theta)}-\frac{ \nabla p(x|\theta) \nabla p(x|\theta)^T}{ p(x|\theta) p(x|\theta)} \\ &= \frac{H_{ p(x|\theta)}}{ p(x|\theta)}-(\frac{ \nabla p(x|\theta)}{ p(x|\theta) })(\frac{ \nabla p(x|\theta)}{ p(x|\theta) })^T \end{aligned} Hlogp(xθ)=J(logp(xtheta))=J(p(xθ)p(xθ))=p(xθ)p(xθ)Hp(xθ)p(xθ)p(xθ)p(xθ)T=p(xθ)Hp(xθ)p(xθ)p(xθ)p(xθ)p(xθ)T=p(xθ)Hp(xθ)(p(xθ)p(xθ))(p(xθ)p(xθ))T
Fisher 信息阵为
E p ( x ∣ θ ) [ H log ⁡ p ( x ∣ θ ) ] = E p ( x ∣ θ ) [ H p ( x ∣ θ ) p ( x ∣ θ ) ] − E p ( x ∣ θ ) [ ( ∇ p ( x ∣ θ ) p ( x ∣ θ ) ) ( ∇ p ( x ∣ θ ) p ( x ∣ θ ) ) T ] = ∫ H p ( x ∣ θ ) p ( x ∣ θ ) p ( x ∣ θ ) d θ − E p ( x ∣ θ ) [ ∇ log ⁡ p ( x ∣ θ ) ∇ log ⁡ p ( x ∣ θ ) T ] = ∫ H p ( x ∣ θ ) p ( x ∣ θ ) p ( x ∣ θ ) d θ − E p ( x ∣ θ ) [ ( S ( θ ) − 0 ) ( S ( θ ) − 0 ) T ] = ∫ H p ( x ∣ θ ) d θ − F = H 1 − F = 0 − F = − F \begin{aligned} \mathop{E}\limits_{p(x|\theta)}[H_{\log p(x|\theta)}] &= \mathop{E}\limits_{p(x|\theta)}[\frac{H_{ p(x|\theta)}}{ p(x|\theta)}]- \mathop{E}\limits_{p(x|\theta)}[(\frac{ \nabla p(x|\theta)}{p(x|\theta) })(\frac{ \nabla p(x|\theta)}{ p(x|\theta) })^T] \\ &= \int \frac{H_{ p(x|\theta)}}{ p(x|\theta)} p(x|\theta) d\theta - \mathop{E}\limits_{p(x|\theta)}[\nabla \log p(x|\theta) \nabla \log p(x|\theta)^T] \\ & = \int \frac{H_{ p(x|\theta)}}{ p(x|\theta)} p(x|\theta) d\theta - \mathop{E}\limits_{p(x|\theta)}[(S(\theta)-0)(S(\theta)-0)^T] \\ &= \int {H_{ p(x|\theta)}} d\theta -F \\ &= H_1 -F = 0-F \\ &=-F \end{aligned} p(xθ)E[Hlogp(xθ)]=p(xθ)E[p(xθ)Hp(xθ)]p(xθ)E[(p(xθ)p(xθ))(p(xθ)p(xθ))T]=p(xθ)Hp(xθ)p(xθ)dθp(xθ)E[logp(xθ)logp(xθ)T]=p(xθ)Hp(xθ)p(xθ)dθp(xθ)E[(S(θ)0)(S(θ)0)T]=Hp(xθ)dθF=H1F=0F=F
所以,Fisher 信息矩阵是 Hessian 矩阵的负期望。

因为 f ′ ′ ( x ) f''(x) f(x) H H H 的对角线元素,所以 − 1 f ′ ′ ( x ) -\frac{1}{f''(x)} f(x)1 F F F对角线元素的倒数。
所以,损失函数
m i n θ ( L B ( θ ) − ∑ i f i ′ ′ ( θ A ∗ ) ( θ i − θ A , i ∗ ) 2 2 \mathop{min}\limits_{\theta}(L_B(\theta)- \sum_i f''_i(\theta_A^*)\frac{(\theta_i-\theta_{A,i}^*)^2}{2} θmin(LB(θ)ifi(θA)2(θiθA,i)2
可以变为
m i n θ ( L B ( θ ) − ∑ i F i ( θ i − θ A , i ∗ ) 2 2 \mathop{min}\limits_{\theta}(L_B(\theta)- \sum_i F_i\frac{(\theta_i-\theta_{A,i}^*)^2}{2} θmin(LB(θ)iFi2(θiθA,i)2
引入超参 λ \lambda λ 衡量两项的重要程度,可以得到最终的损失
m i n θ ( L B ( θ ) − λ 2 ∑ i F i ( θ i − θ A , i ∗ ) 2 \mathop{min}\limits_{\theta}(L_B(\theta)- \frac{\lambda}{2}\sum_i F_i(\theta_i-\theta_{A,i}^*)^2 θmin(LB(θ)2λiFi(θiθA,i)2
上式即为论文中的公式(3)
到此,论文的核心内容就已经结束了,后面的应用及实验结果在此不再展示。

你可能感兴趣的:(持续学习)