Paper : Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
Code : official
作者根据元学习(meta learning)的表达式提出了MAML算法用来进行元知识的梯度下降,使用一阶近似的方法来避免计算损失函数的二阶导,并在小样本学习任务(few-shot learning)上取得了SOTA的成绩。作者强调MAML算法具有模型无关性,可以适用于任何基于梯度下降优化的模型上。并给出了MAML与监督学习的结合和MAML与强化学习结合的算法,强调了算法的通用性。
在此blog中,Meta-Learning 采用的是 Meta-Learning in Neural Networks: A Survey 一文中的形式化定义方式,之后的话博主会将这篇论文的blog编出来。
元学习的直观理解是 “Learning to learn”,也就是说通过多个任务的表现来改善学习算法本身。考虑对比常规的机器学习,常规监督学习的形式化表示如下:
给定训练集 D = { ( x 1 , y 1 ) , . . . ( x N , y N ) } \mathcal D = \{(x_1,y_1),...(x_N,y_N)\} D={(x1,y1),...(xN,yN)},希望找到一个预测模型 y ^ = f θ ( x ) \widehat y = f_{\theta}(x) y =fθ(x) 在训练集上的最优参数解 θ ∗ \theta^* θ∗,即通过下式求解
θ ∗ = arg min θ ( D ; θ , ω ) \theta^* = \arg\min_{\theta} \mathcal(\mathcal D;\theta,\omega) θ∗=argθmin(D;θ,ω)
这里使用 ω \omega ω 表示该解对某些因素的依赖性,例如针对 θ \theta θ 的优化器选择或针对 f f f 的模型的选择,常见的 ω \omega ω 包括优化器的初始化(SGD的步长等), f f f 模型的参数初始化方法,正则化强度等等。 ω \omega ω 被称为是元知识(meta-knowledge),对于常规的机器学习任务来说,元知识是被人为设定的。元学习就是对元知识进行学习和优化。
学习元知识的过程形式化的表示为如下优化问题的求解过程
min ω E T ∼ p ( T ) L ( D , ω ) \min_\omega \mathbb{E}_{\mathcal{T}\sim p(\mathcal T)}\mathcal L(\mathcal D,\omega) ωminET∼p(T)L(D,ω)
其中 T = { D , L } \mathcal T = \{\mathcal D,\mathcal L\} T={D,L} 表示一个常规机器学习任务,假定多个常规任务都是从某个任务分布 p ( T ) p(\mathcal T) p(T) 中采样出来的。我们希望学到的是跨任务的元知识 ω \omega ω ,这些知识可以泛化到一个之前没有遇到过的数据集上,有助于模型在小样本数据集上进行学习。形式化的表述meta-training过程:
给定一个包含 M 个学习任务的数据集 D = { ( D train ( i ) , D val ( i ) ) } \mathbb D = \{(\mathcal D^{(i)}_\text{train},\mathcal D^{(i)}_\text{val})\} D={(Dtrain(i),Dval(i))},求解问题可以表示为
ω ∗ = arg max ω log p ( ω ∣ D ) \omega^* = \arg \max_{\omega} \log p(\omega|\mathbb D) ω∗=argωmaxlogp(ω∣D)
在meta-testing的过程中,首先需要根据元知识进行学习,然后再进行模型的评估,形式化表述为:
给定某训练阶段不可知的任务 j \mathcal j j,测试模型的参数定义为
θ ∗ ( j ) = arg max θ log p ( θ ∣ ω ∗ , D train ( j ) ) {\theta^{*}}^{(j)} = \arg \max_{\theta} \log p(\theta|\omega^*,\mathcal D^{(j)}_\text{train}) θ∗(j)=argθmaxlogp(θ∣ω∗,Dtrain(j))
从双层优化问题的角度来理解 meta-learning,meta-training可以形式化的表示为下式
ω ∗ = arg min ω ∑ i = 1 M L meta ( θ ∗ ( i ) ( ω ) , ω , D val ( i ) ) s.t. θ ∗ ( i ) ( ω ) = arg min θ L task ( θ , ω , D train ( i ) ) \\\omega^* = \arg\min_\omega \sum_{i=1}^M \mathcal L^\text{meta}({\theta^*}^{(i)}(\omega),\omega,\mathcal D_\text{val}^{(i)}) \\\text{s.t. }{\theta^*}^{(i)}(\omega) = \arg\min_{\theta} \mathcal L^\text{task}(\theta,\omega,\mathcal D_\text{train}^{(i)}) ω∗=argωmini=1∑MLmeta(θ∗(i)(ω),ω,Dval(i))s.t. θ∗(i)(ω)=argθminLtask(θ,ω,Dtrain(i))
其中 L task \mathcal L^\text{task} Ltask 和 L meta \mathcal L^\text{meta} Lmeta 分别对应内层和外层的优化目标(损失函数)。
对于few-shot learning来说,一个常见的术语是 N-way K-shot classification,表示对于分类任务,类别总数有 N 个,每个类下面有 K 个样本。
MAML算法的前提是,存在对于模型来说存在某些参数初始化,比其他的初始化方法具有更好的迁移性,更适合做迁移学习。对于MAML算法来说,元知识 ω \omega ω 表示模型的初始化参数,想要解决的问题是小样本学习的问题。小样本集意味着不能在复杂的模型上进行多轮训练,不然会产生overfit问题。MAML通过元学习的方法学到一种对新任务损失函数敏感的初始化方法,使得模型在初始化后经过较少的epoch就可以finetune到一个比较良好的表现上。
我们的目标是得到一个初始化模型参数可以经过较少的epoch获得一个良好的表现,因此,为了简化起见,假定 θ ∗ ( i ) {\theta^*}^{(i)} θ∗(i) 表示进行了一步梯度下降的结果,即
θ ∗ ( i ) = ω − ε ▽ ω L T i ( f ω ) {\theta^*}^{(i)} = \omega - \varepsilon \triangledown_{\omega} \mathcal L_{\mathcal{T}_i}(f_\omega) θ∗(i)=ω−ε▽ωLTi(fω)
而外层的优化目标为
min ω ∑ T i ∼ p ( T ) L T i ( f θ ∗ ( i ) ) \min_\omega \sum_{\mathcal T_i\sim p(\mathcal T)}\mathcal L_{\mathcal T_i}(f_{{\theta^*}^{(i)}}) ωminTi∼p(T)∑LTi(fθ∗(i))
使用SGD算法优化元知识 ω \omega ω ,即
ω ← ω − ϵ ▽ ω ∑ T i ∼ p ( T ) L T i ( f θ ∗ ( i ) ) \omega\leftarrow \omega -\epsilon \triangledown_{\omega} \sum_{\mathcal T_i\sim p(\mathcal T)}\mathcal L_{\mathcal T_i}(f_{{\theta^*}^{(i)}}) ω←ω−ϵ▽ωTi∼p(T)∑LTi(fθ∗(i))
考虑将 ω \omega ω 和 θ ∗ ( i ) {\theta^*}^{(i)} θ∗(i) 都表示为向量,有
▽ ω L ( f θ ∗ ) = [ ∂ L ( f θ ∗ ) ∂ ω 1 , . . . , ∂ L ( f θ ∗ ) ∂ ω K ] T ∂ L ( f θ ∗ ) ∂ ω i = ∑ j ∂ L ( f θ ∗ ) ∂ θ j ∗ ∂ θ j ∗ ∂ ω i \\ \triangledown_{\omega} \mathcal L(f_{\theta^*}) = [\frac{\partial \mathcal L(f_{\theta^*})}{\partial \omega_1},...,\frac{\partial \mathcal L(f_{\theta^*})}{\partial \omega_K}]^\text T \\\; \\ \frac{\partial \mathcal L(f_{\theta^*})}{\partial \omega_i} = \sum_j \frac{\partial \mathcal L(f_{\theta^*})}{\partial \theta^*_j}\frac{\partial \theta^*_j}{\partial \omega_i} ▽ωL(fθ∗)=[∂ω1∂L(fθ∗),...,∂ωK∂L(fθ∗)]T∂ωi∂L(fθ∗)=j∑∂θj∗∂L(fθ∗)∂ωi∂θj∗
根据单步SGD的前提,有
θ j ∗ = ω j − ε ∂ L ( f ω ) ∂ ω j \theta^*_j = \omega_j-\varepsilon \frac{\partial \mathcal L(f_\omega)}{\partial \omega_j} θj∗=ωj−ε∂ωj∂L(fω)
上述求导过程涉及到对梯度函数求梯度,结果存在二阶导,如下所示
∂ θ j ∗ ∂ ω i = { 1 − ε ∂ 2 L ( f ω ) ∂ ω j ∂ ω i i = j − ε ∂ 2 L ( f ω ) ∂ ω j ∂ ω i i ≠ j \frac{\partial \theta^*_j}{\partial \omega_i} = \left\{\begin{matrix} 1-\varepsilon \frac{\partial^2 \mathcal L(f_\omega)}{\partial \omega_j \partial \omega_i} & i= j\\ -\varepsilon \frac{\partial^2 \mathcal L(f_\omega)}{\partial \omega_j \partial \omega_i} & i\not = j \end{matrix}\right. ∂ωi∂θj∗={1−ε∂ωj∂ωi∂2L(fω)−ε∂ωj∂ωi∂2L(fω)i=ji=j
对表达式进行一阶近似,假定 ε → 0 + \varepsilon \rightarrow 0^+ ε→0+ ,有
∂ θ j ∗ ∂ ω i = { 1 i = j 0 i ≠ j \frac{\partial \theta^*_j}{\partial \omega_i} = \left\{\begin{matrix} 1 & i= j\\ 0 & i\not = j \end{matrix}\right. ∂ωi∂θj∗={10i=ji=j
因此,代入结果有
▽ ω L ( f θ ∗ ) ≈ ▽ θ ∗ L ( f θ ∗ ) \\ \triangledown_{\omega} \mathcal L(f_{\theta^*}) \approx \triangledown_{\theta^*} \mathcal L(f_{\theta^*}) ▽ωL(fθ∗)≈▽θ∗L(fθ∗)
一阶近似的MAML元知识更新式子表示为
ω ← ω − ϵ ∑ T i ∼ p ( T ) ▽ θ ∗ ( i ) L T i ( f θ ∗ ( i ) ) θ ∗ ( i ) = ω − ε ▽ ω L T i ( f ω ) \omega\leftarrow \omega -\epsilon \sum_{\mathcal T_i\sim p(\mathcal T)} \triangledown_{{\theta^*}^{(i)}} \mathcal L_{\mathcal T_i}(f_{{\theta^*}^{(i)}}) \\ {\theta^*}^{(i)} = \omega - \varepsilon \triangledown_{\omega} \mathcal L_{\mathcal{T}_i}(f_\omega) ω←ω−ϵTi∼p(T)∑▽θ∗(i)LTi(fθ∗(i))θ∗(i)=ω−ε▽ωLTi(fω)
RL任务定义为
T i = ( q i ( x 1 ) , q i ( x t + 1 ∣ x t , a t ) , L T i , R i ) \mathcal T_i = (q_i(x_1),q_i(x_{t+1}|x_t,a_t),\mathcal L_{\mathcal T_i},R_i) Ti=(qi(x1),qi(xt+1∣xt,at),LTi,Ri)
其中 q q q 表示初始状态分布和状态转移分布,损失函数表示为
L T i ( f ω ) = − E x h , a h ∼ f ω , q T i [ ∑ h = 1 H R i ( x h , a h ) ] \mathcal L_{\mathcal T_i}(f_\omega) = -\mathbb E_{x_h,a_h\sim f_\omega,q_{\mathcal T_i}}[\sum_{h=1}^HR_i(x_h,a_h)] LTi(fω)=−Exh,ah∼fω,qTi[h=1∑HRi(xh,ah)]
对于RL任务,通常使用Policy Gradient 方法进行梯度估计。
作者通过实验观察到,一阶近似的性能与使用二阶导数获得的性能几乎相同,这表明MAML的大部分改进都来自目标在更新后参数值处的一阶梯度,而不是来自更新后参数值的二阶梯度。 过去的工作已经观察到ReLU神经网络在局部几乎是线性的,这表明在大多数情况下二阶导数可能接近于零,部分解释了一阶近似的良好性能。
作者给出了一种基于梯度下降来学习模型初始化参数的元学习方法,方法简单,不会为元学习引入任何学习的参数。它可以与任何适合基于梯度训练的模型表示以及任何可区分的目标(包括分类,回归和强化学习)相结合。MAML的训练过程中只考虑了内部模型参数的一步更新,但是在测试时可以进行充分的finetune。这项工作是迈向一种简单通用的元学习技术的一步,该技术可应用于任何问题和模型。