[持续学习] Fisher信息矩阵与EWC

文章目录

  • 一、前置知识
    • 1. 得分函数 score / informant
    • 2. Fisher信息矩阵
  • 二、EWC
    • 1. 数学推导
    • 2. 如何计算 Fisher 信息矩阵

一、前置知识

1. 得分函数 score / informant

score / informant 定义为对数似然函数关于参数的梯度:

s ( θ ) ≡ ∂ log ⁡ L ( θ ) ∂ θ s(\theta) \equiv \frac{\partial\log{\mathcal{L}(\theta)}}{\partial\theta} s(θ)θlogL(θ)

其中 L ( θ ) \mathcal{L}(\theta) L(θ)即为似然函数,可扩写为 L ( θ ∣ x ) \mathcal{L}(\theta |x) L(θx),其中 x x x为观测到的数据, x x x从采样域 X \mathcal{X} X中产生

在某一特定点的 s s s 函数指明了该点处对数似然函数的陡峭程度(steepness),或者是函数值对参数发生无穷小量变化的敏感性。

如果对数似然函数定义在连续实数值的参数空间中,那么它的函数值将在(局部)极大值与极小值点消失。这一性质通常用于极大似然估计中(maximum likelihood estimation, MLE),来寻找使得似然函数值极大的参数值。


注意 L ( θ ∣ x ) \mathcal{L}(\theta |x) L(θx)中竖线前后的字母 θ ∣ x \theta|x θx x x x为随机变量,在这里则是一个定值,意为采样后的观测值,而 θ \theta θ则为自变量,意为参数模型中的参数

当(假设 θ \theta θ位于正确值时,我们可以通过 θ \theta θ推导 x x x,也就是 f ( x ∣ θ ) f(x|\theta) f(xθ) ,为一概率密度函数,意为当模型参数为 θ \theta θ时,采样到 x x x的概率

从两个角度得到了对同一事实的论证,因此可写作 f ( x ∣ θ ) = L ( θ ∣ x ) f(x|\theta) = \mathcal{L}(\theta | x) f(xθ)=L(θx)


首先,来分析 s s s的数学期望,这里讨论的问题是:当参数取值为 θ \theta θ时, s ∣ θ s|\theta sθ的数学期望

从直观上分析,当参数位于真实最佳)参数点时,似然函数有其极大值(考虑极大似然估计的定义),因此为一极值点,所以该点梯度为 0 0 0,即 E [ s ∣ θ ] = 0 \mathbb{E}[s|\theta]= 0 E[sθ]=0

下面进行公式分析:

首先要明确,该期望是 s s s函数关于什么随机变量的期望。从上面的讨论中可以得到,该问题中唯一的随机变量是采样观测值 x x x,它的采样概率是 f ( x ∣ θ ) f(x|\theta) f(xθ)

注意:

f ( x ) ∂ log ⁡ f ( x ) ∂ x = f ( x ) 1 f ( x ) ∂ f ( x ) ∂ x = ∂ f ( x ) ∂ x \begin{aligned} & f(x) \frac{\partial\log{f(x)}}{\partial{x}} \\ = & {f(x)} \frac{1}{f(x)} \frac{\partial{f(x)}}{\partial{x}} \\ = & \frac{\partial{f(x)}}{\partial{x}} \end{aligned} ==f(x)xlogf(x)f(x)f(x)1xf(x)xf(x)

所以:

E [ s ∣ θ ] = ∫ X f ( x ∣ θ ) ⋅ s ⋅ d x = ∫ X f ( x ∣ θ ) ∂ log ⁡ L ( θ ∣ x ) ∂ θ d x = ∫ X f ( x ∣ θ ) ∂ log ⁡ f ( x ∣ θ ) ∂ θ d x = ∫ X ∂ f ( x ∣ θ ) ∂ θ d x = ∂ ∂ x ∫ X f ( x ∣ θ ) d x = ∂ ∂ x 1 = 0 ■ \begin{aligned} \mathbb{E}[s|\theta] & = \int_{\mathcal{X}}f(x|\theta)\cdot{}s\cdot{}\mathrm{d}x \\ & = \int_{\mathcal{X}}f(x|\theta) \frac{\partial\log{\mathcal{L}(\theta|x)}}{\partial\theta} \mathrm{d}x \\ & = \int_{\mathcal{X}}f(x|\theta) \frac{\partial\log{f(x|\theta)}}{\partial\theta} \mathrm{d}x \\ & = \int_{\mathcal{X}} \frac{\partial f(x|\theta)}{\partial{\theta}} \mathrm{d}x \\ & = \frac{\partial}{\partial{x}} \int_{\mathcal{X}}f(x|\theta)\mathrm{d}x \\ & = \frac{\partial}{\partial{x}} 1 \\ & = 0\qquad\blacksquare \\ \end{aligned} E[sθ]=Xf(xθ)sdx=Xf(xθ)θlogL(θx)dx=Xf(xθ)θlogf(xθ)dx=Xθf(xθ)dx=xXf(xθ)dx=x1=0

因此得证: E [ s ∣ θ ] = 0 \mathbb{E}[s|\theta]= 0 E[sθ]=0

2. Fisher信息矩阵

Fisher信息(Fisher information),或简称为信息(information)是一种衡量信息量的指标

假设我们想要建模一个随机变量 x x x 的分布,用于建模的参数是 θ \theta θ,那么Fisher信息测量了 x x x 携带的对于 θ \theta θ 的信息量

所以,当我们固定 θ \theta θ 值,以 x x x 为自变量,Fisher 信息应当指出这一 x x x 值可贡献给 θ \theta θ 多少信息量

比如说,某一 θ \theta θ 点附近的函数平面非常陡峭(有一极值峰值),那么我们不需要采样多少 x x x 即可做出比较好的估计,也就是采样点 x x x 的Fisher 信息量较高。反之,若某一 θ \theta θ 附近的函数平面连续且平缓,那么我们需要采样很多点才能做出比较好的估计,也就是 Fisher 信息量较低。

从这一直观定义出发,我们可以联想到随机变量的方差,因此对于一个(假设的)真实参数 θ \theta θ s s s 函数的 Fisher 信息定义为 s s s 函数的方差

I ( θ ) = E [ ( ∂ ∂ θ log ⁡ f ( x ∣ θ ) ) 2 ∣ θ ] = ∫ ( ∂ ∂ θ log ⁡ f ( x ∣ θ ) ) 2 f ( x ; θ ) d x \begin{aligned} \mathcal{I} (\theta) & =\mathbb{E}\left[\left.\left({\frac {\partial }{\partial \theta }}\log f(x|\theta )\right)^{2}\right|\theta \right] \\ & = \int \left({\frac {\partial }{\partial \theta }}\log f(x|\theta )\right)^{2}f(x;\theta )\mathrm{d}x \end{aligned} I(θ)=E[(θlogf(xθ))2θ]=(θlogf(xθ))2f(x;θ)dx

此外,如果 log ⁡ f ( x ∣ θ ) \log f(x|\theta) logf(xθ) 对于 θ \theta θ 二次可微,那么 Fisher 信息还可以写作

I ( θ ) = − E [ ∂ 2 ∂ 2 θ log ⁡ f ( x ∣ θ ) ∣ θ ] \mathcal{I}(\theta) = -\mathbb{E}\left[\left.{\frac {\partial^2}{\partial^2 \theta }}\log f(x|\theta )\right|\theta \right] I(θ)=E[2θ2logf(xθ)θ]

证明如下:

∵ 0 = E [ s ∣ θ ] ∴ 0 = ∂ ∂ θ E [ s ∣ θ ] = ∂ ∂ θ ∫ X f ( x ∣ θ ) ∂ log ⁡ L ( θ ∣ x ) ∂ θ d x = ∫ X ∂ ∂ θ ∂ log ⁡ L ( θ ∣ x ) ∂ θ f ( x ∣ θ )   d x ▹  use chain rule = ∫ X { ∂ 2 log ⁡ L ( θ ∣ x ) ∂ 2 θ f ( x ∣ θ ) + ∂ f ( x ∣ θ ) ∂ θ ∂ log ⁡ L ( θ ∣ x ) ∂ θ } d x = ∫ X ∂ 2 log ⁡ L ( θ ∣ x ) ∂ 2 θ f ( x ∣ θ ) d x ⏟ A + ∫ X ∂ L ( θ ∣ x ) ∂ θ ∂ log ⁡ L ( θ ∣ x ) ∂ θ d x ⏟ B A = E [ ∂ 2 log ⁡ L ( θ ∣ x ) ∂ 2 θ ∣ θ ] B = ∫ X ∂ L ( θ ∣ x ) ∂ θ ∂ log ⁡ L ( θ ∣ x ) ∂ θ d x = ∫ X ∂ log ⁡ L ( θ ∣ x ) ∂ θ L ( θ ∣ x ) ∂ log ⁡ L ( θ ∣ x ) ∂ θ d x = ∫ X ( ∂ log ⁡ L ( θ ∣ x ) ∂ θ ) 2 f ( x ∣ θ ) d x = E [ ( ∂ log ⁡ L ( θ ∣ x ) ∂ θ ) 2 ∣ θ ] ∵ A + B = 0 ∴ E [ ∂ 2 log ⁡ L ( θ ∣ x ) ∂ 2 θ ∣ θ ] + E [ ( ∂ log ⁡ L ( θ ∣ x ) ∂ θ ) 2 ∣ θ ] = 0 \begin{aligned} \because 0 & = \mathbb{E}[s|\theta] \\ \\ \therefore 0 & = \frac{\partial}{\partial\theta} \mathbb{E}[s|\theta] \\ & = \frac{\partial}{\partial\theta}\int_{\mathcal{X}} f(x|\theta) \frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \mathrm{d}x \\ & = \int_{\mathcal{X}} \frac{\partial}{\partial\theta} \boxed{\frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} f(x|\theta)}\ \mathrm{d}x \quad{\triangleright\ \textrm{use chain rule}}\\ & = \int_{\mathcal{X}} \left\{ \frac{\partial^2\log\mathcal{L}(\theta|x)}{\partial^2\theta}f(x|\theta) + \frac{\partial f(x|\theta)}{\partial\theta} \frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}}\right\} \mathrm{d}x \\ & = \underbrace{\int_{\mathcal{X}} \frac{\partial^2\log\mathcal{L}(\theta|x)}{\partial^2\theta}f(x|\theta) \mathrm{d}x }_\mathbf{A} + \underbrace{\int_{\mathcal{X}}\frac{\partial \mathcal{L}(\theta|x)}{\partial\theta} \frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \mathrm{d}x}_{\mathbf{B}} \\ \\ \mathbf{A} &= \mathbb{E}\left[\left. \frac{\partial^2\log\mathcal{L}(\theta|x)}{\partial^2\theta}\right| \theta \right] \\ \mathbf{B} &= \int_{\mathcal{X}} \red{\frac{\partial \mathcal{L}(\theta|x)}{\partial\theta}} \frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \mathrm{d}x \\ &= \int_{\mathcal{X}}\red{\frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}}\mathcal{L}(\theta|x)} \frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \mathrm{d}x\\ &= \int_{\mathcal{X}} \left(\frac{ \partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \right)^2 f(x|\theta)\mathrm{d}x \\ &= \mathbb{E}\left[\left. \left(\frac{ \partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \right)^2 \right| \theta \right] \\ \\ & \because \mathbf{A}+\mathbf{B} = 0 \\ & \therefore \mathbb{E}\left[\left. \frac{\partial^2\log\mathcal{L}(\theta|x)}{\partial^2\theta}\right| \theta \right] + \mathbb{E}\left[\left. \left(\frac{ \partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \right)^2 \right| \theta \right] = 0 \end{aligned} 00AB=E[sθ]=θE[sθ]=θXf(xθ)θlogL(θx)dx=XθθlogL(θx)f(xθ) dx use chain rule=X{2θ2logL(θx)f(xθ)+θf(xθ)θlogL(θx)}dx=A X2θ2logL(θx)f(xθ)dx+B XθL(θx)θlogL(θx)dx=E[2θ2logL(θx)θ]=XθL(θx)θlogL(θx)dx=XθlogL(θx)L(θx)θlogL(θx)dx=X(θlogL(θx))2f(xθ)dx=E[(θlogL(θx))2θ]A+B=0E[2θ2logL(θx)θ]+E[(θlogL(θx))2θ]=0

二、EWC

1. 数学推导

假设数据集被划分为两个任务 Σ = { A , B } \Sigma = \{\mathcal{A, B}\} Σ={A,B},网络参数为 θ \theta θ

学习任务为最大化后验概率

arg max ⁡ θ P ( θ ∣ Σ ) = arg max ⁡ θ log ⁡ P ( θ ∣ Σ ) = arg min ⁡ θ l ( θ ) \begin{aligned} & \argmax_{\theta}P(\theta | \Sigma) \\ = & \argmax_{\theta} \log P(\theta | \Sigma) \\ = & \argmin_{\theta} l(\theta) \end{aligned} ==θargmaxP(θΣ)θargmaxlogP(θΣ)θargminl(θ)

其中 l ( θ ) l(\theta) l(θ)定义为训练 loss

考虑任务训练顺序 A ⇒ B \mathcal{A} \Rightarrow \mathcal{B} AB

log ⁡ P ( θ ∣ Σ ) = log ⁡ P ( θ ∣ A , B ) = log ⁡ P ( θ , A , B ) − log ⁡ P ( A , B ) = log ⁡ P ( B ∣ θ , A ) + log ⁡ P ( θ , A ) − log ⁡ P ( B ∣ A ) − log ⁡ P ( A ) = log ⁡ P ( B ∣ θ ) + log ⁡ P ( θ ∣ A ) + log ⁡ P ( A ) − log ⁡ P ( B ) − log ⁡ P ( A ) ▹ A, B i.i.d  = log ⁡ P ( B ∣ θ ) ⏟ l o s s   o n   B + log ⁡ P ( θ ∣ A ) ⏟ unknown − log ⁡ P ( B ) ⏟ constant \begin{aligned} & \log P(\theta | \Sigma) \\ &= \log P(\theta | A, B) \\ &= \log P(\theta, A, B) - \log P(A, B) \\ &= \log P(B|\theta, A) + \log P(\theta, A) - \log P(B|A) - \log P(A) \\ &= \log P(B|\theta) + \log P(\theta | A) + \log P(A) - \log P(B) - \log P(A) \qquad \triangleright \textrm{A, B i.i.d } \\ &= \underbrace{\log P(B|\theta)}_{\mathrm{loss\ on\ \mathcal{B}}} + \underbrace{\log P(\theta|A)}_{\textrm{unknown}} - \underbrace{\log P(B)}_{\textrm{constant}} \\ \end{aligned} logP(θΣ)=logP(θA,B)=logP(θ,A,B)logP(A,B)=logP(Bθ,A)+logP(θ,A)logP(BA)logP(A)=logP(Bθ)+logP(θA)+logP(A)logP(B)logP(A)A, B i.i.d =loss on B logP(Bθ)+unknown logP(θA)constant logP(B)

其中后验概率 log ⁡ P ( θ ∣ A ) \log P(\theta | A) logP(θA) 不易得到,因此使用拉普拉斯近似进行分析

在训练任务 B \mathcal{B} B 之前,网络已经在任务 A \mathcal{A} A 上收敛,设网络此时的参数为 θ A ∗ \theta_{A}^* θA,为在任务 A \mathcal{A} A 上拟合得到的参数,设函数 f ( θ ) = log ⁡ P ( θ ∣ A ) f(\theta) = \log P(\theta | A) f(θ)=logP(θA)

f ( θ ) f(\theta) f(θ) θ = θ A ∗ \theta = \theta_{A}^* θ=θA 做泰勒展开:

f ( θ ) = f ( θ A ∗ ) + ∂ f ( θ ) ∂ θ ∣ θ A ∗ ⏟ = 0 ( θ − θ A ∗ ) + 1 2 ( θ − θ A ∗ ) T ∂ 2 f ( θ ) ∂ 2 θ ∣ θ A ∗ ( θ − θ A ∗ ) + ⋯ ≈ f ( θ A ∗ ) + 1 2 ( θ − θ A ∗ ) T ∂ 2 f ( θ ) ∂ 2 θ ∣ θ A ∗ ( θ − θ A ∗ ) \begin{aligned} f(\theta) &= f(\theta_{A}^*) + \underbrace{\left.\frac{\partial f(\theta)}{\partial\theta}\right|_{\theta_A^*}}_{=0}(\theta - \theta_{A}^*) + \frac{1}{2}(\theta - \theta_{A}^*)^T \left.\frac{\partial^2f(\theta)}{\partial^2\theta}\right|_{\theta_A^*} (\theta - \theta_{A}^*) + \cdots \\ &\approx f(\theta_{A}^*) + \frac{1}{2}(\theta - \theta_{A}^*)^T \left.\frac{\partial^2f(\theta)}{\partial^2\theta}\right|_{\theta_A^*} (\theta - \theta_{A}^*) \\ \end{aligned} f(θ)=f(θA)+=0 θf(θ)θA(θθA)+21(θθA)T2θ2f(θ)θA(θθA)+f(θA)+21(θθA)T2θ2f(θ)θA(θθA)

f ( θ ) = log ⁡ P ( θ ∣ A ) f(\theta) = \log P(\theta | A) f(θ)=logP(θA) 代入:

log ⁡ P ( θ ∣ A ) = log ⁡ P ( θ A ∗ ∣ A ) + 1 2 ( θ − θ A ∗ ) T ∂ 2 log ⁡ P ( θ ∣ A ) ∂ 2 θ ∣ θ A ∗ ( θ − θ A ∗ ) = log ⁡ P ( θ A ∗ ∣ A ) + 1 2 ( θ − θ A ∗ ) T { − [ − ∂ 2 log ⁡ P ( θ ∣ A ) ∂ 2 θ ∣ θ A ∗ ] − 1 } − 1 ( θ − θ A ∗ ) P ( θ ∣ A ) = exp ⁡ [ Δ + 1 2 ( θ − θ A ∗ ) T { − [ − ∂ 2 log ⁡ P ( θ ∣ A ) ∂ 2 θ ∣ θ A ∗ ] − 1 } − 1 ( θ − θ A ∗ ) ] = ϵ exp ⁡ [ − 1 2 ( θ − θ A ∗ ) T { [ − ∂ 2 log ⁡ P ( θ ∣ A ) ∂ 2 θ ∣ θ A ∗ ] − 1 } − 1 ⏟ Σ − 1 ( θ − θ A ∗ ) ] where: Δ = log ⁡ P ( θ A ∗ ∣ A ) ϵ = exp ⁡ Δ \begin{aligned} \log P(\theta|A) &= \log P(\theta_A^* | A) + \frac{1}{2}(\theta - \theta_{A}^*)^T \left.\frac{\partial^2\log P(\theta|A)}{\partial^2\theta}\right|_{\theta_A^*} (\theta - \theta_{A}^*) \\ &= \log P(\theta_A^* | A) + \frac{1}{2}(\theta - \theta_{A}^*)^T\left\{-\left[-\left.\frac{\partial^2\log P(\theta|A)}{\partial^2\theta}\right|_{\theta_A^*}\right]^{-1}\right\}^{-1} (\theta - \theta_{A}^*) \\ P(\theta|A) &= \exp{\left[ \Delta + \frac{1}{2}(\theta - \theta_{A}^*)^T\left\{-\left[-\left.\frac{\partial^2\log P(\theta|A)}{\partial^2\theta}\right|_{\theta_A^*}\right]^{-1}\right\}^{-1} (\theta - \theta_{A}^*) \right]} \\ &= \epsilon\exp{\left[-\frac{1}{2}(\theta - \theta_{A}^*)^T \underbrace{\left\{\left[-\left.\frac{\partial^2\log P(\theta|A)}{\partial^2\theta}\right|_{\theta_A^*}\right]^{-1}\right\}^{-1}}_{\Sigma^{-1}} (\theta - \theta_{A}^*) \right]} \\ \textbf{where:} & \\ \Delta &= \log P(\theta_A^* | A) \\ \epsilon &= \exp{\Delta} \end{aligned} logP(θA)P(θA)where:Δϵ=logP(θAA)+21(θθA)T2θ2logP(θA)θA(θθA)=logP(θAA)+21(θθA)T[2θ2logP(θA)θA]11(θθA)=expΔ+21(θθA)T[2θ2logP(θA)θA]11(θθA)=ϵexp21(θθA)TΣ1 [2θ2logP(θA)θA]11(θθA)=logP(θAA)=expΔ

观察形式可得:

P ( θ ∣ A ) ∼ N ( θ A ∗ , ( − ∂ 2 log ⁡ P ( θ ∣ A ) ∂ 2 θ ∣ θ A ∗ ) − 1 ) P(\theta | \mathcal{A}) \sim \mathcal{N}\left(\theta_{\mathcal{A}}^*,\left(-\left.\frac{\partial^2\log P(\theta|A)}{\partial^2\theta}\right|_{\theta_A^*}\right)^{-1}\right) P(θA)NθA,(2θ2logP(θA)θA)1

其中协方差矩阵项正是第一部分讨论的Fisher信息矩阵,记做 I A \mathbf{I}_{\mathcal{A}} IA,则有

P ( θ ∣ A ) ∼ N ( θ A ∗ , [ I A ] − 1 ) P(\theta | \mathcal{A}) \sim \mathcal{N}\left(\theta_{\mathcal{A}}^*,\left[\mathbf{I}_{\mathcal{A}}\right]^{-1} \right) P(θA)N(θA,[IA]1)

另外,EWC是以一个参数的视角出发的,因此Fisher信息矩阵只需要对角线元素,其余计算出来的结果可以置0,所以有:

P ( θ ∣ A ) = 1 ( 2 π ) k ∣ Σ ∣ exp ⁡ { − 1 2 ( θ − θ A ∗ ) T Σ − 1 ( θ − θ A ∗ ) } log ⁡ P ( θ i ∣ A ) = − 1 2 ( θ i − [ θ i ] A ∗ ) 2 ∗ [ Σ − 1 ] i i = − [ I A ] i i ( θ i − [ θ i ] A ∗ ) 2 2 \begin{aligned} P(\theta | A) &= \frac{1}{\sqrt{(2\pi)^k |\Sigma|}} \exp\{-\frac{1}{2}(\theta - \theta_A^*)^T\Sigma^{-1}(\theta - \theta_{A}^*)\}\\ \\ \log P(\theta_{i} | A) &= -\frac{1}{2}(\theta_i - [\theta_i]_A^*)^2 * [\Sigma^{-1}]_{ii} \\ &= -\left[\mathbf{I}_{\mathcal{A}}\right]_{ii}\frac{(\theta_i - [\theta_i]_A^*)^2}{2} \end{aligned} P(θA)logP(θiA)=(2π)kΣ 1exp{21(θθA)TΣ1(θθA)}=21(θi[θi]A)2[Σ1]ii=[IA]ii2(θi[θi]A)2

所以,所有参数的EWC Loss可定义为:

l EWC = − ∑ i = 1 #Params [ I A ] i i ( θ i − [ θ i ] A ∗ ) 2 2 \begin{aligned} l_{\textbf{EWC}} = -\sum_{i=1}^{\textrm{\#Params}} \left[\mathbf{I}_{\mathcal{A}}\right]_{ii}\frac{(\theta_i - [\theta_i]_A^*)^2}{2} \end{aligned} lEWC=i=1#Params[IA]ii2(θi[θi]A)2

将上述内容代入总优化目标:

l ( θ ) = log ⁡ P ( θ ∣ Σ ) = log ⁡ P ( B ∣ θ ) + log ⁡ P ( θ ∣ A ) − log ⁡ P ( B ) ⇒ l CE ( θ ∣ B ) + l EWC ( θ ∣ A ) \begin{aligned} l(\theta) &= \log P(\theta | \Sigma) \\ &= \log P(B|\theta) + \log P(\theta|A) - \log P(B) \\ &\Rightarrow l_{\textbf{CE}} (\theta | \mathcal{B}) + l_{\textbf{EWC}}(\theta | \mathcal{A}) \end{aligned} l(θ)=logP(θΣ)=logP(Bθ)+logP(θA)logP(B)lCE(θB)+lEWC(θA)

定义超参数 λ \lambda λ 进行稳定性-可塑性权衡

l ( θ ) = l CE ( θ ∣ B ) + λ ⋅ l EWC ( θ ∣ A ) \begin{aligned} l(\theta) &= l_{\textbf{CE}} (\theta | \mathcal{B}) + \lambda \cdot l_{\textbf{EWC}}(\theta | \mathcal{A}) \end{aligned} l(θ)=lCE(θB)+λlEWC(θA)

因此优化目标为:

arg min ⁡ θ l ( θ ) = arg min ⁡ θ { l CE ( θ ∣ B ) + λ ⋅ l EWC ( θ ∣ A ) } = arg min ⁡ θ { l CE ( θ ∣ B ) − λ 2 ∑ i = 1 #Params [ I A ] i i ( θ i − [ θ i ] A ∗ ) 2 2 } \begin{aligned} \argmin_{\theta}l(\theta) &= \argmin_{\theta}\left\{l_{\textbf{CE}} (\theta | \mathcal{B}) + \lambda \cdot l_{\textbf{EWC}}(\theta | \mathcal{A})\right\} \\ &= \argmin_{\theta} \left\{ l_{\textbf{CE}} (\theta | \mathcal{B}) - \frac{\lambda}{2}\sum_{i=1}^{\textrm{\#Params}} \left[\mathbf{I}_{\mathcal{A}}\right]_{ii}\frac{(\theta_i - [\theta_i]_A^*)^2}{2} \right\} \end{aligned} θargminl(θ)=θargmin{lCE(θB)+λlEWC(θA)}=θargmin{lCE(θB)2λi=1#Params[IA]ii2(θi[θi]A)2}

2. 如何计算 Fisher 信息矩阵

将训练过程进行划分:

  1. 使用数据集 A \mathcal{A} A l CE ( θ ∣ A ) l_{\textbf{CE}}(\theta|\mathcal{A}) lCE(θA) 训练模型
  2. 保存此时的参数,即 θ A ∗ \theta_{\mathcal{A}}^* θA ,并计算 Fisher 信息矩阵 I A \mathbf{I}_{\mathcal{A}} IA
  3. 使用数据集 B \mathcal{B} B l CE ( θ ∣ B ) l_{\textbf{CE}}(\theta|\mathcal{B}) lCE(θB) l EWC ( θ ∣ A ) l_{\textbf{EWC}}(\theta | \mathcal{A}) lEWC(θA) 训练模型
  4. 多任务 { A , B , C , … } \{\mathcal{A, B, C, \dots}\} {A,B,C,}同理

最后一个问题,如何使用 A \mathcal{A} A 计算 I A \mathbf{I}_{\mathcal{A}} IA

考虑定义

I ( θ ) = E [ ( ∂ ∂ θ log ⁡ f ( x ∣ θ ) ) 2 ∣ θ ] \mathcal{I}(\theta) = \mathbb{E}\left[\left.\left({\frac {\partial }{\partial \theta }}\log f(x|\theta )\right)^{2}\right|\theta \right] \\ I(θ)=E[(θlogf(xθ))2θ]

可以通过计算梯度的平方来获得每一个参数的 Fisher 信息矩阵项:

I ( θ ) = 1 N ∑ ( x , y ) i ∈ A ( ∂ l LL ( θ ∣ ( x , y ) i ) ∂ θ ⏟ Gradient ) 2 \mathcal{I}(\theta) = \frac{1}{N} \sum_{(x,y)_{i} \in \mathcal{A}}\left(\underbrace{\frac{\partial{l_{\textbf{LL}}(\theta|(x, y)_{i})}}{\partial\theta}}_{\textrm{Gradient}}\right)^2 I(θ)=N1(x,y)iAGradient θlLL(θ(x,y)i)2

具体来说,可以向模型逐个喂入样本,并计算损失函数,使用神经网络框架自动计算梯度。对于每个参数,累加所有的梯度,最后除以样本数量,即可得到对应参数的 Fisher 信息矩阵项

需要注意的是,当使用 nn.CrossEntrypyLossnn.NLLLoss时,由于其中对于 Log-Likelihood \textrm{Log-Likelihood} Log-Likelihood使用了相反数处理,使用该类损失函数得到的矩阵是真实 Fisher 信息矩阵的相反数,即 I ^ ( θ ) = − I ( θ ) \hat{\mathcal{I}}(\theta) = -\mathcal{I}(\theta) I^(θ)=I(θ),在计算 loss 时要记得将减号改成加号,即

arg min ⁡ θ l ( θ ) = arg min ⁡ θ { l CE ( θ ∣ B ) + λ 2 ∑ i = 1 #Params [ I A ] ^ i i ( θ i − [ θ i ] A ∗ ) 2 2 } \argmin_{\theta}l(\theta) = \argmin_{\theta} \left\{ l_{\textbf{CE}} (\theta | \mathcal{B}) + \frac{\lambda}{2}\sum_{i=1}^{\textrm{\#Params}} \red{\hat{\left[\mathbf{I}_{\mathcal{A}}\right]}}_{ii}\frac{(\theta_i - [\theta_i]_A^*)^2}{2} \right\} θargminl(θ)=θargmin{lCE(θB)+2λi=1#Params[IA]^ii2(θi[θi]A)2}

全文完

版权声明:自由转载-非商用-禁止演绎-保持署名 4.0 国际许可协议

你可能感兴趣的:(机器学习,深度学习,数学,深度学习)