The score (also called the informant) is defined as the gradient of the log-likelihood function with respect to the parameter:
$$
s(\theta) \equiv \frac{\partial\log\mathcal{L}(\theta)}{\partial\theta}
$$
where $\mathcal{L}(\theta)$ is the likelihood function. It can be written more explicitly as $\mathcal{L}(\theta \mid x)$, where $x$ is the observed data, drawn from the sample space $\mathcal{X}$.

The value of $s$ at a particular point indicates the steepness of the log-likelihood there, that is, the sensitivity of the function value to an infinitesimal change in the parameter.

If the log-likelihood is differentiable on a continuous real-valued parameter space, the score vanishes at its (local) maxima and minima. This property underlies maximum likelihood estimation (MLE), where one searches for the parameter value that maximizes the likelihood.
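As a small illustration (a hypothetical Bernoulli model of my own, not from the derivation above), the score has a closed form and vanishes exactly at the maximum-likelihood estimate:

```python
import numpy as np

# Score of a Bernoulli(theta) model. For observations x_1..x_n in {0, 1}:
#   log L(theta) = k*log(theta) + (n - k)*log(1 - theta),  k = sum(x)
#   s(theta)     = k/theta - (n - k)/(1 - theta)
def score(theta, x):
    n, k = len(x), np.sum(x)
    return k / theta - (n - k) / (1 - theta)

x = np.array([1, 0, 1, 1, 0, 1, 1, 0])  # toy sample: 5 ones out of 8
mle = x.mean()                          # MLE of theta is the sample mean
print(score(mle, x))                    # the score vanishes at the MLE: 0.0
```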
Note the two symbols around the bar in $\mathcal{L}(\theta \mid x)$: $x$ is a random variable, but here it is held fixed as the observed sample, while $\theta$ is the free variable, the parameter of the model.

When $\theta$ is (assumed to be) at its correct value, we can go the other way and describe $x$ given $\theta$ via $f(x \mid \theta)$, a probability density function: the probability of sampling $x$ when the model parameter is $\theta$.

The two views describe the same underlying quantity, so we may write $f(x \mid \theta) = \mathcal{L}(\theta \mid x)$.
First, let us analyze the expectation of $s$. The question is: when the parameter takes the value $\theta$, what is the expectation $\mathbb{E}[s \mid \theta]$?

Intuitively, when the parameter sits at its true (optimal) value, the likelihood attains its maximum (recall the definition of maximum likelihood estimation). That point is therefore an extremum, so the gradient there is $0$, i.e. $\mathbb{E}[s \mid \theta] = 0$.

A formal derivation follows.

First, be clear about which random variable the expectation of $s$ is taken over. From the discussion above, the only random variable in this problem is the observed sample $x$, which is drawn with density $f(x \mid \theta)$.

Note that:
$$
\begin{aligned}
f(x)\,\frac{\partial\log f(x)}{\partial x}
&= f(x)\,\frac{1}{f(x)}\,\frac{\partial f(x)}{\partial x} \\
&= \frac{\partial f(x)}{\partial x}
\end{aligned}
$$
Therefore:
$$
\begin{aligned}
\mathbb{E}[s \mid \theta]
&= \int_{\mathcal{X}} f(x \mid \theta)\cdot s\,\mathrm{d}x \\
&= \int_{\mathcal{X}} f(x \mid \theta)\,\frac{\partial\log\mathcal{L}(\theta \mid x)}{\partial\theta}\,\mathrm{d}x \\
&= \int_{\mathcal{X}} f(x \mid \theta)\,\frac{\partial\log f(x \mid \theta)}{\partial\theta}\,\mathrm{d}x \\
&= \int_{\mathcal{X}} \frac{\partial f(x \mid \theta)}{\partial\theta}\,\mathrm{d}x \\
&= \frac{\partial}{\partial\theta}\int_{\mathcal{X}} f(x \mid \theta)\,\mathrm{d}x \qquad \triangleright\ \text{swap derivative and integral (regularity assumed)} \\
&= \frac{\partial}{\partial\theta}\,1 \\
&= 0 \qquad\blacksquare
\end{aligned}
$$
This proves $\mathbb{E}[s \mid \theta] = 0$.
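A quick Monte Carlo check of this result (my own toy example, assuming a Gaussian model with unit variance, where the score is simply $x - \theta$):

```python
import numpy as np

# For f(x | theta) = N(theta, 1): log f = -(x - theta)^2 / 2 + const,
# so the score is s = d/dtheta log f = x - theta.
rng = np.random.default_rng(0)
theta_true = 2.0
x = rng.normal(theta_true, 1.0, size=1_000_000)

mean_score = (x - theta_true).mean()   # estimate of E[s | theta]
print(mean_score)                      # ~0 up to Monte Carlo noise
```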
Fisher information (or simply information) is a measure of information content.

Suppose we want to model the distribution of a random variable $x$, using a parameter $\theta$. The Fisher information then measures how much information about $\theta$ is carried by $x$.

So, fixing the value of $\theta$ and treating $x$ as the variable, the Fisher information should indicate how much information this value of $x$ can contribute about $\theta$.

For example, if the likelihood surface near some $\theta$ is very steep (a sharp peak), then we do not need many samples of $x$ to make a good estimate; the Fisher information of the samples is high. Conversely, if the surface near some $\theta$ is smooth and flat, many samples are needed for a good estimate, and the Fisher information is low.

This intuition suggests the variance of a random variable. Accordingly, for a (hypothesized) true parameter $\theta$, the Fisher information is defined as the variance of the score $s$ (whose mean, as shown above, is zero):
$$
\begin{aligned}
\mathcal{I}(\theta)
&= \mathbb{E}\left[\left.\left(\frac{\partial}{\partial\theta}\log f(x \mid \theta)\right)^2\right|\theta\right] \\
&= \int_{\mathcal{X}}\left(\frac{\partial}{\partial\theta}\log f(x \mid \theta)\right)^2 f(x \mid \theta)\,\mathrm{d}x
\end{aligned}
$$
Moreover, if $\log f(x \mid \theta)$ is twice differentiable with respect to $\theta$, the Fisher information can also be written as

$$
\mathcal{I}(\theta) = -\mathbb{E}\left[\left.\frac{\partial^2}{\partial\theta^2}\log f(x \mid \theta)\right|\theta\right]
$$

The proof is as follows:
$$
\begin{aligned}
\because\ 0 &= \mathbb{E}[s \mid \theta] \\
\therefore\ 0 &= \frac{\partial}{\partial\theta}\,\mathbb{E}[s \mid \theta] \\
&= \frac{\partial}{\partial\theta}\int_{\mathcal{X}} f(x \mid \theta)\,\frac{\partial\log\mathcal{L}(\theta \mid x)}{\partial\theta}\,\mathrm{d}x \\
&= \int_{\mathcal{X}} \frac{\partial}{\partial\theta}\left[\frac{\partial\log\mathcal{L}(\theta \mid x)}{\partial\theta}\,f(x \mid \theta)\right]\mathrm{d}x \qquad \triangleright\ \text{use the product rule} \\
&= \int_{\mathcal{X}} \left\{\frac{\partial^2\log\mathcal{L}(\theta \mid x)}{\partial\theta^2}\,f(x \mid \theta) + \frac{\partial f(x \mid \theta)}{\partial\theta}\,\frac{\partial\log\mathcal{L}(\theta \mid x)}{\partial\theta}\right\}\mathrm{d}x \\
&= \underbrace{\int_{\mathcal{X}} \frac{\partial^2\log\mathcal{L}(\theta \mid x)}{\partial\theta^2}\,f(x \mid \theta)\,\mathrm{d}x}_{\mathbf{A}} + \underbrace{\int_{\mathcal{X}} \frac{\partial f(x \mid \theta)}{\partial\theta}\,\frac{\partial\log\mathcal{L}(\theta \mid x)}{\partial\theta}\,\mathrm{d}x}_{\mathbf{B}}
\end{aligned}
$$

$$
\begin{aligned}
\mathbf{A} &= \mathbb{E}\left[\left.\frac{\partial^2\log\mathcal{L}(\theta \mid x)}{\partial\theta^2}\right|\theta\right] \\
\mathbf{B} &= \int_{\mathcal{X}} \frac{\partial f(x \mid \theta)}{\partial\theta}\,\frac{\partial\log\mathcal{L}(\theta \mid x)}{\partial\theta}\,\mathrm{d}x \\
&= \int_{\mathcal{X}} \frac{\partial\log\mathcal{L}(\theta \mid x)}{\partial\theta}\,f(x \mid \theta)\,\frac{\partial\log\mathcal{L}(\theta \mid x)}{\partial\theta}\,\mathrm{d}x \\
&= \int_{\mathcal{X}} \left(\frac{\partial\log\mathcal{L}(\theta \mid x)}{\partial\theta}\right)^2 f(x \mid \theta)\,\mathrm{d}x \\
&= \mathbb{E}\left[\left.\left(\frac{\partial\log\mathcal{L}(\theta \mid x)}{\partial\theta}\right)^2\right|\theta\right]
\end{aligned}
$$

Since $\mathbf{A} + \mathbf{B} = 0$, we conclude

$$
\mathcal{I}(\theta) = \mathbb{E}\left[\left.\left(\frac{\partial\log\mathcal{L}(\theta \mid x)}{\partial\theta}\right)^2\right|\theta\right] = -\mathbb{E}\left[\left.\frac{\partial^2\log\mathcal{L}(\theta \mid x)}{\partial\theta^2}\right|\theta\right] \qquad\blacksquare
$$
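The equality of the two expressions can be checked numerically on a toy Bernoulli model (my own example; the closed-form Fisher information is $1/(\theta(1-\theta))$):

```python
import numpy as np

# Bernoulli(theta): log f(x|theta) = x*log(theta) + (1-x)*log(1-theta)
rng = np.random.default_rng(0)
theta = 0.3
x = rng.binomial(1, theta, size=1_000_000).astype(float)

grad = x / theta - (1 - x) / (1 - theta)          # d log f / d theta
hess = -x / theta**2 - (1 - x) / (1 - theta)**2   # d^2 log f / d theta^2

fisher_sq   = np.mean(grad ** 2)   #  E[(d log f / d theta)^2 | theta]
fisher_hess = -np.mean(hess)       # -E[d^2 log f / d theta^2 | theta]
closed_form = 1 / (theta * (1 - theta))
print(fisher_sq, fisher_hess, closed_form)   # all close to 4.76
```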
Suppose the dataset is divided into two tasks, $\Sigma = \{\mathcal{A}, \mathcal{B}\}$, and the network parameters are $\theta$.

The learning objective is to maximize the posterior probability:
$$
\begin{aligned}
& \argmax_{\theta} P(\theta \mid \Sigma) \\
=\ & \argmax_{\theta} \log P(\theta \mid \Sigma) \\
=\ & \argmin_{\theta}\, l(\theta)
\end{aligned}
$$
where the training loss is defined as $l(\theta) = -\log P(\theta \mid \Sigma)$.

Consider training the tasks in the order $\mathcal{A} \Rightarrow \mathcal{B}$:
$$
\begin{aligned}
\log P(\theta \mid \Sigma)
&= \log P(\theta \mid A, B) \\
&= \log P(\theta, A, B) - \log P(A, B) \\
&= \log P(B \mid \theta, A) + \log P(\theta, A) - \log P(B \mid A) - \log P(A) \\
&= \log P(B \mid \theta) + \log P(\theta \mid A) + \log P(A) - \log P(B) - \log P(A) \qquad \triangleright\ A,\ B\ \text{independent} \\
&= \underbrace{\log P(B \mid \theta)}_{\text{log-likelihood on } \mathcal{B}} + \underbrace{\log P(\theta \mid A)}_{\text{unknown}} - \underbrace{\log P(B)}_{\text{constant}}
\end{aligned}
$$
The posterior $\log P(\theta \mid A)$ is not easy to obtain, so we analyze it with a Laplace approximation.

Before training on task $\mathcal{B}$, the network has already converged on task $\mathcal{A}$. Let the network parameters at that point be $\theta_A^*$, the parameters fitted on task $\mathcal{A}$, and define the function $f(\theta) = \log P(\theta \mid A)$.

Taylor-expand $f(\theta)$ around $\theta = \theta_A^*$:
$$
\begin{aligned}
f(\theta) &= f(\theta_A^*) + \underbrace{\left.\frac{\partial f(\theta)}{\partial\theta}\right|_{\theta_A^*}}_{=0}(\theta-\theta_A^*) + \frac{1}{2}(\theta-\theta_A^*)^T \left.\frac{\partial^2 f(\theta)}{\partial\theta^2}\right|_{\theta_A^*}(\theta-\theta_A^*) + \cdots \\
&\approx f(\theta_A^*) + \frac{1}{2}(\theta-\theta_A^*)^T \left.\frac{\partial^2 f(\theta)}{\partial\theta^2}\right|_{\theta_A^*}(\theta-\theta_A^*)
\end{aligned}
$$
Substituting $f(\theta) = \log P(\theta \mid A)$:
$$
\begin{aligned}
\log P(\theta \mid A) &= \log P(\theta_A^* \mid A) + \frac{1}{2}(\theta-\theta_A^*)^T \left.\frac{\partial^2\log P(\theta \mid A)}{\partial\theta^2}\right|_{\theta_A^*}(\theta-\theta_A^*) \\
&= \log P(\theta_A^* \mid A) + \frac{1}{2}(\theta-\theta_A^*)^T\left\{-\left[-\left.\frac{\partial^2\log P(\theta \mid A)}{\partial\theta^2}\right|_{\theta_A^*}\right]^{-1}\right\}^{-1}(\theta-\theta_A^*) \\
P(\theta \mid A) &= \exp\left[\Delta + \frac{1}{2}(\theta-\theta_A^*)^T\left\{-\left[-\left.\frac{\partial^2\log P(\theta \mid A)}{\partial\theta^2}\right|_{\theta_A^*}\right]^{-1}\right\}^{-1}(\theta-\theta_A^*)\right] \\
&= \epsilon\exp\left[-\frac{1}{2}(\theta-\theta_A^*)^T\underbrace{\left\{\left[-\left.\frac{\partial^2\log P(\theta \mid A)}{\partial\theta^2}\right|_{\theta_A^*}\right]^{-1}\right\}^{-1}}_{\Sigma^{-1}}(\theta-\theta_A^*)\right] \\
\textbf{where: } \Delta &= \log P(\theta_A^* \mid A), \qquad \epsilon = \exp\Delta
\end{aligned}
$$
Matching this against the Gaussian density, we can read off:
$$
P(\theta \mid \mathcal{A}) \sim \mathcal{N}\left(\theta_A^*,\ \left(-\left.\frac{\partial^2\log P(\theta \mid A)}{\partial\theta^2}\right|_{\theta_A^*}\right)^{-1}\right)
$$
The matrix inside the covariance term, the negative Hessian of the log-posterior, is (in expectation) exactly the Fisher information matrix discussed in the first part. Denoting it $\mathbf{I}_{\mathcal{A}}$, we have
$$
P(\theta \mid \mathcal{A}) \sim \mathcal{N}\left(\theta_A^*,\ \left[\mathbf{I}_{\mathcal{A}}\right]^{-1}\right)
$$
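As a sanity check on this Gaussian form (my own toy example, not part of the EWC derivation), a Bernoulli posterior under a flat prior has a closed-form mode and Hessian, and the Laplace variance indeed comes out as the inverse of the total Fisher information:

```python
# Laplace approximation of a toy Bernoulli posterior with a flat prior
# (all numbers are made up for illustration):
#   log P(theta | data) = k*log(theta) + (n - k)*log(1 - theta) + const
n, k = 100, 37                 # n observations, k successes
theta_hat = k / n              # posterior mode (also the MLE)

# Negative Hessian of the log-posterior, evaluated at the mode:
neg_hessian = k / theta_hat**2 + (n - k) / (1 - theta_hat)**2
laplace_var = 1 / neg_hessian  # variance of the Gaussian approximation

# laplace_var equals theta_hat*(1 - theta_hat)/n, i.e. the inverse of
# n times the per-sample Fisher information 1/(theta*(1 - theta)).
print(theta_hat, laplace_var)
```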
In addition, EWC works from the viewpoint of individual parameters, so only the diagonal entries of the Fisher information matrix are needed; the remaining entries can be set to $0$. Therefore:
$$
\begin{aligned}
P(\theta \mid A) &= \frac{1}{\sqrt{(2\pi)^k|\Sigma|}}\exp\left\{-\frac{1}{2}(\theta-\theta_A^*)^T\Sigma^{-1}(\theta-\theta_A^*)\right\} \\
\log P(\theta_i \mid A) &= -\frac{1}{2}(\theta_i - [\theta_i]_A^*)^2\,[\Sigma^{-1}]_{ii} + \text{const} \\
&= -\left[\mathbf{I}_{\mathcal{A}}\right]_{ii}\frac{(\theta_i - [\theta_i]_A^*)^2}{2} + \text{const}
\end{aligned}
$$
The EWC loss over all parameters, the negative of the summed per-parameter log-probabilities with constants dropped, can therefore be defined as:

$$
l_{\textbf{EWC}} = \sum_{i=1}^{\textrm{\#Params}} \left[\mathbf{I}_{\mathcal{A}}\right]_{ii}\frac{(\theta_i - [\theta_i]_A^*)^2}{2}
$$
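A minimal PyTorch sketch of this penalty, assuming `fisher` and `old_params` are dictionaries (parameter name to tensor) saved after training on task $\mathcal{A}$; both names are hypothetical:

```python
import torch

def ewc_penalty(model, fisher, old_params):
    """Diagonal EWC penalty: sum_i I_ii * (theta_i - theta_i*)^2 / 2."""
    penalty = torch.zeros(())
    for name, p in model.named_parameters():
        penalty = penalty + (fisher[name] * (p - old_params[name]) ** 2).sum() / 2
    return penalty
```

During task-$\mathcal{B}$ training, the total loss would then be the task loss plus $\lambda$ times this penalty.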
Substituting everything into the overall optimization objective:
$$
\begin{aligned}
l(\theta) &= -\log P(\theta \mid \Sigma) \\
&= -\log P(B \mid \theta) - \log P(\theta \mid A) + \log P(B) \\
&\Rightarrow l_{\textbf{CE}}(\theta \mid \mathcal{B}) + l_{\textbf{EWC}}(\theta \mid \mathcal{A})
\end{aligned}
$$
Define a hyperparameter $\lambda$ for the stability-plasticity trade-off:
$$
l(\theta) = l_{\textbf{CE}}(\theta \mid \mathcal{B}) + \lambda\cdot l_{\textbf{EWC}}(\theta \mid \mathcal{A})
$$
The optimization objective is therefore:
$$
\begin{aligned}
\argmin_{\theta} l(\theta)
&= \argmin_{\theta}\left\{l_{\textbf{CE}}(\theta \mid \mathcal{B}) + \lambda\cdot l_{\textbf{EWC}}(\theta \mid \mathcal{A})\right\} \\
&= \argmin_{\theta}\left\{l_{\textbf{CE}}(\theta \mid \mathcal{B}) + \frac{\lambda}{2}\sum_{i=1}^{\textrm{\#Params}}\left[\mathbf{I}_{\mathcal{A}}\right]_{ii}(\theta_i - [\theta_i]_A^*)^2\right\}
\end{aligned}
$$
The training process is thus divided into stages: first train to convergence on task $\mathcal{A}$; then record $\theta_A^*$ and compute $\mathbf{I}_{\mathcal{A}}$; finally train on task $\mathcal{B}$ with the EWC penalty added to the loss.
One last question remains: how do we compute $\mathbf{I}_{\mathcal{A}}$ from task $\mathcal{A}$?

Consider the definition:
$$
\mathcal{I}(\theta) = \mathbb{E}\left[\left.\left(\frac{\partial}{\partial\theta}\log f(x \mid \theta)\right)^2\right|\theta\right]
$$
Each parameter's entry of the Fisher information matrix can be obtained by averaging squared gradients:
$$
\mathcal{I}(\theta) = \frac{1}{N}\sum_{(x,y)_i\in\mathcal{A}}\left(\underbrace{\frac{\partial l_{\textbf{LL}}(\theta \mid (x,y)_i)}{\partial\theta}}_{\text{Gradient}}\right)^2
$$
Concretely, feed the samples to the model one at a time, compute the loss, and let the deep-learning framework compute the gradients automatically. For each parameter, accumulate the squared gradients over all samples and finally divide by the number of samples; this yields that parameter's entry of the Fisher information matrix.
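A hedged PyTorch sketch of this procedure; `model`, `loader` (yielding one sample per step), and `nll` (a negative-log-likelihood-style loss) are assumed to exist and are named here for illustration only:

```python
import torch

def estimate_fisher(model, loader, nll):
    """Estimate the diagonal Fisher information as the average of
    squared per-sample gradients of the (negative) log-likelihood."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    n_samples = 0
    for x, y in loader:                        # one sample at a time
        model.zero_grad()
        nll(model(x), y).backward()            # framework computes gradients
        for n, p in model.named_parameters():
            fisher[n] += p.grad.detach() ** 2  # squaring removes the sign
        n_samples += 1
    return {n: f / n_samples for n, f in fisher.items()}
```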
Note that `nn.CrossEntropyLoss` and `nn.NLLLoss` negate the $\textrm{Log-Likelihood}$ internally. This does not change the estimate: since each gradient is squared before accumulation, the sign cancels, so the gradients of such loss functions can be used directly, and the penalty keeps its plus sign in the objective:

$$
\argmin_{\theta} l(\theta) = \argmin_{\theta}\left\{l_{\textbf{CE}}(\theta \mid \mathcal{B}) + \frac{\lambda}{2}\sum_{i=1}^{\textrm{\#Params}}\left[\mathbf{I}_{\mathcal{A}}\right]_{ii}(\theta_i - [\theta_i]_A^*)^2\right\}
$$
End of article.
Copyright notice: CC BY-NC-ND 4.0 International license (free to share, non-commercial, no derivatives, attribution required).