A major challenge with the original MAML algorithm is that the outer loop (the meta-update) must differentiate through the inner loop (gradient-based adaptation), which in general requires storing the optimization trace and computing higher-order derivatives. The core idea of this paper is implicit differentiation: computing the meta-gradient requires only the *solution* of the inner-loop optimization, not the path the inner optimizer took to reach it.

Benefits: ① the meta-gradient computation (outer loop) is decoupled from the choice of inner-loop optimizer, so any inner optimizer can be used; ② taking many inner gradient steps no longer suffers from vanishing gradients or memory constraints.
As the figure above shows, MAML computes the meta-gradient by differentiating through the inner-loop optimization path, while first-order MAML simply approximates $\frac{d\phi_i}{d\theta}$ by the identity $I$. iMAML instead derives an exact analytic expression for the meta-gradient by estimating the local curvature at the inner-loop solution (the meta-gradient is expressed through the solution itself rather than through derivatives of how it was reached), so the optimization path never has to be differentiated.

The advantages: the optimization path need not be stored or differentiated, so many gradient steps can be taken efficiently in the inner loop; and the whole method is agnostic to the choice of inner optimizer, requiring only an approximate solution of the inner-loop problem. This makes it possible to use higher-order or even non-differentiable optimization methods.
$$\overbrace{\boldsymbol{\theta}_{\mathrm{ML}}^{*}:=\underset{\boldsymbol{\theta}\in\Theta}{\operatorname{argmin}}\;F(\boldsymbol{\theta})}^{\text{outer-level}},\ \text{where}\ F(\boldsymbol{\theta})=\frac{1}{M}\sum_{i=1}^{M}\mathcal{L}\Big(\overbrace{\mathcal{A}lg\big(\boldsymbol{\theta},\mathcal{D}_{i}^{\mathrm{tr}}\big)}^{\text{inner-level}},\mathcal{D}_{i}^{\text{test}}\Big)$$
Here $\mathcal{A}lg$ denotes the inner-loop algorithm, whose output is the task-adapted parameters. To prevent overfitting, a regularization term can be added to the inner-loop objective:
$$\mathcal{A}lg^{\star}\big(\boldsymbol{\theta},\mathcal{D}_{i}^{\mathrm{tr}}\big)=\underset{\phi'\in\Phi}{\arg\min}\ \mathcal{L}(\phi',\mathcal{D}^{\mathrm{tr}}_{i})+\frac{\lambda}{2}\lVert\phi'-\theta\rVert^{2}$$
Here $\theta$ is the meta-parameter we are solving for (i.e., the model initialization). It is treated as a constant during the inner loop and updated by gradient steps in the outer loop; the actual inner-loop variable is the adapted parameter $\phi'$. The $\star$ marks the exact solution; in practice, iterative gradient methods return only an approximation of this optimum.
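As a concrete illustration (not code from the paper), here is a minimal PyTorch sketch of $\mathcal{A}lg$, assuming the adapted parameters are flattened into a single tensor and a hypothetical `train_loss_fn` evaluates $\hat{\mathcal{L}}_i$; plain gradient descent is just the simplest choice, since the method allows any inner optimizer here.

```python
import torch

def adapt(theta, train_loss_fn, lam=1.0, steps=25, lr=0.05):
    """Approximately solve Alg*(θ): min_φ  L̂(φ) + (λ/2)·||φ − θ||²,
    with θ held fixed; only the adapted parameters φ are optimized."""
    phi = theta.detach().clone().requires_grad_(True)
    for _ in range(steps):
        # Regularized inner objective G(φ, θ) = L̂(φ) + (λ/2)·||φ − θ||².
        loss = train_loss_fn(phi) + 0.5 * lam * (phi - theta.detach()).pow(2).sum()
        (g,) = torch.autograd.grad(loss, phi)
        with torch.no_grad():
            phi -= lr * g          # plain gradient descent step
    return phi.detach()            # an estimate of Alg*(θ, D_i^tr)
```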
The bi-level optimization problem can then be rewritten as:

$$\begin{array}{l}\boldsymbol{\theta}_{\mathrm{ML}}^{*}:=\underset{\boldsymbol{\theta}\in\Theta}{\operatorname{argmin}}\ F(\boldsymbol{\theta}),\ \text{where}\ F(\boldsymbol{\theta})=\frac{1}{M}\sum_{i=1}^{M}\mathcal{L}_{i}\big(\mathcal{A}lg_{i}^{\star}(\boldsymbol{\theta})\big),\ \text{and}\\[4pt]\mathcal{A}lg_{i}^{\star}(\boldsymbol{\theta}):=\underset{\boldsymbol{\phi}^{\prime}\in\Phi}{\operatorname{argmin}}\ G_{i}\big(\boldsymbol{\phi}^{\prime},\boldsymbol{\theta}\big),\ \text{where}\ G_{i}\big(\boldsymbol{\phi}^{\prime},\boldsymbol{\theta}\big)=\hat{\mathcal{L}}_{i}\big(\boldsymbol{\phi}^{\prime}\big)+\frac{\lambda}{2}\big\|\boldsymbol{\phi}^{\prime}-\boldsymbol{\theta}\big\|^{2}\end{array}$$

where
$$\mathcal{L}_{i}(\phi):=\mathcal{L}\big(\phi,\mathcal{D}_{i}^{\text{test}}\big),\quad\hat{\mathcal{L}}_{i}(\phi):=\mathcal{L}\big(\phi,\mathcal{D}_{i}^{\mathrm{tr}}\big),\quad\mathcal{A}lg_{i}(\boldsymbol{\theta}):=\mathcal{A}lg\big(\boldsymbol{\theta},\mathcal{D}_{i}^{\mathrm{tr}}\big)$$

Writing $d$ for total derivatives and $\nabla$ for partial derivatives, the chain rule gives the meta-gradient:
$$d_{\boldsymbol{\theta}}\mathcal{L}_i(\mathcal{A}lg_{i}(\boldsymbol{\theta}))=\frac{d\mathcal{A}lg_{i}(\boldsymbol{\theta})}{d\boldsymbol{\theta}}\,\nabla_\phi\mathcal{L}_i(\phi)\Big|_{\phi=\mathcal{A}lg_{i}(\boldsymbol{\theta})}=\frac{d\mathcal{A}lg_{i}(\boldsymbol{\theta})}{d\boldsymbol{\theta}}\,\nabla_\phi\mathcal{L}_i(\mathcal{A}lg_{i}(\boldsymbol{\theta}))$$
In this expression, $\nabla_\phi\mathcal{L}_i(\mathcal{A}lg_{i}(\boldsymbol{\theta}))=\nabla_\phi\mathcal{L}_i(\phi)\big|_{\phi=\mathcal{A}lg_{i}(\boldsymbol{\theta})}$ is easy to compute once $\mathcal{A}lg^{\star}_{i}(\boldsymbol{\theta})$ has been solved for (by gradient descent or any other optimization method). The Jacobian $\frac{d\mathcal{A}lg_{i}(\boldsymbol{\theta})}{d\boldsymbol{\theta}}$ is harder: differentiating through the updates directly involves higher-order derivatives and requires recording the entire update trajectory. Instead, define the result of the inner loop (adaptation), $\phi_i=\mathcal{A}lg^{\star}_{i}(\boldsymbol{\theta})$, implicitly as the solution of the inner optimization problem. Its derivative can then be computed without any reference to the optimization path (Lemma 1):
$$\frac{d\mathcal{A}lg^{\star}_{i}(\boldsymbol{\theta})}{d\boldsymbol{\theta}}=\left(\boldsymbol{I}+\frac{1}{\lambda}\nabla^2_\phi\hat{\mathcal{L}}_i(\phi_i)\right)^{-1}$$
Proof: when $\phi_i=\mathcal{A}lg^{\star}_{i}$ minimizes $G_{i}\big(\boldsymbol{\phi}^{\prime},\boldsymbol{\theta}\big)=\hat{\mathcal{L}}_{i}\big(\boldsymbol{\phi}^{\prime}\big)+\frac{\lambda}{2}\big\|\boldsymbol{\phi}^{\prime}-\boldsymbol{\theta}\big\|^{2}$, it satisfies the first-order necessary condition, i.e., the gradient vanishes:
$$\nabla_{\boldsymbol{\phi}^{\prime}}G\big(\boldsymbol{\phi}^{\prime},\boldsymbol{\theta}\big)\Big|_{\boldsymbol{\phi}^{\prime}=\boldsymbol{\phi}_i}=0\ \Longrightarrow\ \nabla\hat{\mathcal{L}}(\boldsymbol{\phi}_i)+\lambda(\boldsymbol{\phi}_i-\boldsymbol{\theta})=0\ \Longrightarrow\ \boldsymbol{\phi}_i=\boldsymbol{\theta}-\frac{1}{\lambda}\nabla\hat{\mathcal{L}}(\boldsymbol{\phi}_i)$$

This is a standard implicit equation. Where the derivative exists, differentiating both sides with respect to $\boldsymbol{\theta}$ gives:
$$\frac{d\boldsymbol{\phi}_i}{d\boldsymbol{\theta}}=I-\frac{1}{\lambda}\nabla^{2}\hat{\mathcal{L}}(\boldsymbol{\phi}_i)\frac{d\boldsymbol{\phi}_i}{d\boldsymbol{\theta}}\ \Longrightarrow\ \Big(I+\frac{1}{\lambda}\nabla^{2}\hat{\mathcal{L}}(\boldsymbol{\phi}_i)\Big)\frac{d\boldsymbol{\phi}_i}{d\boldsymbol{\theta}}=I$$

Solving for $\frac{d\boldsymbol{\phi}_i}{d\boldsymbol{\theta}}$ gives the stated inverse, completing the proof.
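As a quick numerical sanity check (mine, not the paper's), Lemma 1 can be verified on a quadratic inner loss $\hat{\mathcal{L}}(\phi)=\frac{1}{2}\phi^{\top}A\phi+b^{\top}\phi$, where the regularized inner problem has a closed-form solution and the Hessian is the constant matrix $A$:

```python
import numpy as np

rng = np.random.default_rng(0)
n, lam = 4, 2.0

# Quadratic inner loss L̂(φ) = ½·φᵀAφ + bᵀφ, so ∇²L̂ ≡ A.
Q = rng.standard_normal((n, n))
A = Q @ Q.T + np.eye(n)                 # symmetric positive-definite Hessian
b = rng.standard_normal(n)
theta = rng.standard_normal(n)

# Stationarity ∇L̂(φ) + λ(φ − θ) = 0 gives the closed form (A + λI)φ = λθ − b.
phi = np.linalg.solve(A + lam * np.eye(n), lam * theta - b)
assert np.allclose(A @ phi + b + lam * (phi - theta), 0)

# Jacobian dφ/dθ from the closed form: λ(A + λI)⁻¹ ...
J_closed_form = lam * np.linalg.inv(A + lam * np.eye(n))
# ... matches the Lemma 1 expression (I + (1/λ)·∇²L̂(φ_i))⁻¹.
J_lemma = np.linalg.inv(np.eye(n) + A / lam)
print(np.allclose(J_closed_form, J_lemma))   # True
```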
Combining Lemma 1 with the chain rule gives the exact meta-gradient $d_{\boldsymbol{\theta}}\mathcal{L}_i=\big(I+\frac{1}{\lambda}\nabla^2_\phi\hat{\mathcal{L}}_i(\phi_i)\big)^{-1}\nabla_\phi\mathcal{L}_i(\phi_i)$, but evaluating it runs into two difficulties. First, the lemma holds for the exact solution $\mathcal{A}lg^{\star}_{i}$, whereas inner-loop optimization typically yields only an approximate one. Second, the expression involves a matrix inverse and second derivatives, both prohibitively expensive for deep neural networks. The paper therefore computes an approximation; the key condition is:
$$\Big\|\boldsymbol{g}_{i}-\Big(I+\frac{1}{\lambda}\nabla_{\boldsymbol{\phi}}^{2}\hat{\mathcal{L}}_{i}(\boldsymbol{\phi}_{i})\Big)^{-1}\nabla_{\boldsymbol{\phi}}\mathcal{L}_{i}(\boldsymbol{\phi}_{i})\Big\|\leq\delta^{\prime}$$

Here $\boldsymbol{g}_i$ is the estimate of the meta-gradient $d_{\boldsymbol{\theta}}\mathcal{L}_i(\mathcal{A}lg_{i}(\boldsymbol{\theta}))$, and $\boldsymbol{\phi}_i$ is the approximation of the optimum $\mathcal{A}lg^{\star}_{i}$ returned by iterative gradient-based optimization. Finding such a $\boldsymbol{g}_i$ can then be cast as a quadratic optimization problem:
$$\min_{\boldsymbol{w}}\ \frac{1}{2}\boldsymbol{w}^{\top}\Big(\boldsymbol{I}+\frac{1}{\lambda}\nabla_{\boldsymbol{\phi}}^{2}\hat{\mathcal{L}}_{i}(\boldsymbol{\phi}_{i})\Big)\boldsymbol{w}-\boldsymbol{w}^{\top}\nabla_{\boldsymbol{\phi}}\mathcal{L}_{i}(\boldsymbol{\phi}_{i})$$

This can be solved quickly with the conjugate gradient method, which only ever needs Hessian-vector products $\nabla^2\hat{\mathcal{L}_i}(\boldsymbol{\phi}_i)\boldsymbol{v}$ (where $\boldsymbol{v}$ is the current conjugate direction), never the explicit Hessian or its inverse.
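Below is a minimal PyTorch sketch of this conjugate-gradient meta-gradient computation, reusing the hypothetical `adapt` helper from above and again assuming parameters flattened into a single tensor; `train_loss_fn` and `test_loss_fn` stand in for $\hat{\mathcal{L}}_i$ and $\mathcal{L}_i$.

```python
import torch

def meta_gradient(theta, train_loss_fn, test_loss_fn, lam=1.0, cg_iters=10):
    """Estimate g_i ≈ (I + (1/λ)·∇²L̂(φ_i))⁻¹ ∇L_i(φ_i) by conjugate gradient,
    touching the Hessian only through Hessian-vector products."""
    phi = adapt(theta, train_loss_fn, lam=lam)       # approximate φ_i = Alg*_i(θ)
    phi.requires_grad_(True)

    # Right-hand side: the test-loss gradient at the adapted parameters.
    b = torch.autograd.grad(test_loss_fn(phi), phi)[0].detach()

    # v ↦ (I + (1/λ)·∇²L̂(φ_i))·v via double backprop (no explicit Hessian).
    g_train = torch.autograd.grad(train_loss_fn(phi), phi, create_graph=True)[0]
    def Av(v):
        hv = torch.autograd.grad(g_train, phi, grad_outputs=v,
                                 retain_graph=True)[0]
        return v + hv / lam

    # Standard conjugate gradient for A·w = b.
    w = torch.zeros_like(b)
    r, p = b.clone(), b.clone()
    rs = r.dot(r)
    for _ in range(cg_iters):
        Ap = Av(p)
        alpha = rs / p.dot(Ap)
        w = w + alpha * p
        r = r - alpha * Ap
        rs_new = r.dot(r)
        p = r + (rs_new / rs) * p
        rs = rs_new
    return w                                         # g_i, the meta-gradient estimate
```

Averaged over the minibatch of tasks, the resulting $\boldsymbol{g}_i$ then serves as the gradient for the outer-loop update of $\theta$.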