MAML

Paper : Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
Code : official

摘要

作者根据元学习(meta learning)的表达式提出了MAML算法用来进行元知识的梯度下降,使用一阶近似的方法来避免计算损失函数的二阶导,并在小样本学习任务(few-shot learning)上取得了SOTA的成绩。作者强调MAML算法具有模型无关性,可以适用于任何基于梯度下降优化的模型上。并给出了MAML与监督学习的结合和MAML与强化学习结合的算法,强调了算法的通用性。

Meta-Learning

在此blog中,Meta-Learning 采用的是 Meta-Learning in Neural Networks: A Survey 一文中的形式化定义方式,之后的话博主会将这篇论文的blog编出来。

元学习的直观理解是 “Learning to learn”,也就是说通过多个任务的表现来改善学习算法本身。考虑对比常规的机器学习,常规监督学习的形式化表示如下:

给定训练集 D = { ( x 1 , y 1 ) , . . . ( x N , y N ) } \mathcal D = \{(x_1,y_1),...(x_N,y_N)\} D={(x1,y1),...(xN,yN)},希望找到一个预测模型 y ^ = f θ ( x ) \widehat y = f_{\theta}(x) y =fθ(x) 在训练集上的最优参数解 θ ∗ \theta^* θ,即通过下式求解

θ ∗ = arg ⁡ min ⁡ θ ( D ; θ , ω ) \theta^* = \arg\min_{\theta} \mathcal(\mathcal D;\theta,\omega) θ=argθmin(D;θ,ω)

这里使用 ω \omega ω 表示该解对某些因素的依赖性,例如针对 θ \theta θ 的优化器选择或针对 f f f 的模型的选择,常见的 ω \omega ω 包括优化器的初始化(SGD的步长等), f f f 模型的参数初始化方法,正则化强度等等。 ω \omega ω 被称为是元知识(meta-knowledge),对于常规的机器学习任务来说,元知识是被人为设定的。元学习就是对元知识进行学习和优化。

学习元知识的过程形式化的表示为如下优化问题的求解过程

min ⁡ ω E T ∼ p ( T ) L ( D , ω ) \min_\omega \mathbb{E}_{\mathcal{T}\sim p(\mathcal T)}\mathcal L(\mathcal D,\omega) ωminETp(T)L(D,ω)

其中 T = { D , L } \mathcal T = \{\mathcal D,\mathcal L\} T={D,L} 表示一个常规机器学习任务,假定多个常规任务都是从某个任务分布 p ( T ) p(\mathcal T) p(T) 中采样出来的。我们希望学到的是跨任务的元知识 ω \omega ω ,这些知识可以泛化到一个之前没有遇到过的数据集上,有助于模型在小样本数据集上进行学习。形式化的表述meta-training过程:

给定一个包含 M 个学习任务的数据集 D = { ( D train ( i ) , D val ( i ) ) } \mathbb D = \{(\mathcal D^{(i)}_\text{train},\mathcal D^{(i)}_\text{val})\} D={(Dtrain(i),Dval(i))},求解问题可以表示为

ω ∗ = arg ⁡ max ⁡ ω log ⁡ p ( ω ∣ D ) \omega^* = \arg \max_{\omega} \log p(\omega|\mathbb D) ω=argωmaxlogp(ωD)

在meta-testing的过程中,首先需要根据元知识进行学习,然后再进行模型的评估,形式化表述为:

给定某训练阶段不可知的任务 j \mathcal j j,测试模型的参数定义为

θ ∗ ( j ) = arg ⁡ max ⁡ θ log ⁡ p ( θ ∣ ω ∗ , D train ( j ) ) {\theta^{*}}^{(j)} = \arg \max_{\theta} \log p(\theta|\omega^*,\mathcal D^{(j)}_\text{train}) θ(j)=argθmaxlogp(θω,Dtrain(j))

从双层优化问题的角度来理解 meta-learning,meta-training可以形式化的表示为下式

ω ∗ = arg ⁡ min ⁡ ω ∑ i = 1 M L meta ( θ ∗ ( i ) ( ω ) , ω , D val ( i ) ) s.t.  θ ∗ ( i ) ( ω ) = arg ⁡ min ⁡ θ L task ( θ , ω , D train ( i ) ) \\\omega^* = \arg\min_\omega \sum_{i=1}^M \mathcal L^\text{meta}({\theta^*}^{(i)}(\omega),\omega,\mathcal D_\text{val}^{(i)}) \\\text{s.t. }{\theta^*}^{(i)}(\omega) = \arg\min_{\theta} \mathcal L^\text{task}(\theta,\omega,\mathcal D_\text{train}^{(i)}) ω=argωmini=1MLmeta(θ(i)(ω),ω,Dval(i))s.t. θ(i)(ω)=argθminLtask(θ,ω,Dtrain(i))

其中 L task \mathcal L^\text{task} Ltask L meta \mathcal L^\text{meta} Lmeta 分别对应内层和外层的优化目标(损失函数)。

对于few-shot learning来说,一个常见的术语是 N-way K-shot classification,表示对于分类任务,类别总数有 N 个,每个类下面有 K 个样本。

MAML

MAML算法的前提是,存在对于模型来说存在某些参数初始化,比其他的初始化方法具有更好的迁移性,更适合做迁移学习。对于MAML算法来说,元知识 ω \omega ω 表示模型的初始化参数,想要解决的问题是小样本学习的问题。小样本集意味着不能在复杂的模型上进行多轮训练,不然会产生overfit问题。MAML通过元学习的方法学到一种对新任务损失函数敏感的初始化方法,使得模型在初始化后经过较少的epoch就可以finetune到一个比较良好的表现上。

我们的目标是得到一个初始化模型参数可以经过较少的epoch获得一个良好的表现,因此,为了简化起见,假定 θ ∗ ( i ) {\theta^*}^{(i)} θ(i) 表示进行了一步梯度下降的结果,即

θ ∗ ( i ) = ω − ε ▽ ω L T i ( f ω ) {\theta^*}^{(i)} = \omega - \varepsilon \triangledown_{\omega} \mathcal L_{\mathcal{T}_i}(f_\omega) θ(i)=ωεωLTi(fω)

而外层的优化目标为

min ⁡ ω ∑ T i ∼ p ( T ) L T i ( f θ ∗ ( i ) ) \min_\omega \sum_{\mathcal T_i\sim p(\mathcal T)}\mathcal L_{\mathcal T_i}(f_{{\theta^*}^{(i)}}) ωminTip(T)LTi(fθ(i))

使用SGD算法优化元知识 ω \omega ω ,即

ω ← ω − ϵ ▽ ω ∑ T i ∼ p ( T ) L T i ( f θ ∗ ( i ) ) \omega\leftarrow \omega -\epsilon \triangledown_{\omega} \sum_{\mathcal T_i\sim p(\mathcal T)}\mathcal L_{\mathcal T_i}(f_{{\theta^*}^{(i)}}) ωωϵωTip(T)LTi(fθ(i))

考虑将 ω \omega ω θ ∗ ( i ) {\theta^*}^{(i)} θ(i) 都表示为向量,有

▽ ω L ( f θ ∗ ) = [ ∂ L ( f θ ∗ ) ∂ ω 1 , . . . , ∂ L ( f θ ∗ ) ∂ ω K ] T    ∂ L ( f θ ∗ ) ∂ ω i = ∑ j ∂ L ( f θ ∗ ) ∂ θ j ∗ ∂ θ j ∗ ∂ ω i \\ \triangledown_{\omega} \mathcal L(f_{\theta^*}) = [\frac{\partial \mathcal L(f_{\theta^*})}{\partial \omega_1},...,\frac{\partial \mathcal L(f_{\theta^*})}{\partial \omega_K}]^\text T \\\; \\ \frac{\partial \mathcal L(f_{\theta^*})}{\partial \omega_i} = \sum_j \frac{\partial \mathcal L(f_{\theta^*})}{\partial \theta^*_j}\frac{\partial \theta^*_j}{\partial \omega_i} ωL(fθ)=[ω1L(fθ),...,ωKL(fθ)]TωiL(fθ)=jθjL(fθ)ωiθj

根据单步SGD的前提,有

θ j ∗ = ω j − ε ∂ L ( f ω ) ∂ ω j \theta^*_j = \omega_j-\varepsilon \frac{\partial \mathcal L(f_\omega)}{\partial \omega_j} θj=ωjεωjL(fω)

上述求导过程涉及到对梯度函数求梯度,结果存在二阶导,如下所示

∂ θ j ∗ ∂ ω i = { 1 − ε ∂ 2 L ( f ω ) ∂ ω j ∂ ω i i = j − ε ∂ 2 L ( f ω ) ∂ ω j ∂ ω i i ≠ j \frac{\partial \theta^*_j}{\partial \omega_i} = \left\{\begin{matrix} 1-\varepsilon \frac{\partial^2 \mathcal L(f_\omega)}{\partial \omega_j \partial \omega_i} & i= j\\ -\varepsilon \frac{\partial^2 \mathcal L(f_\omega)}{\partial \omega_j \partial \omega_i} & i\not = j \end{matrix}\right. ωiθj={1εωjωi2L(fω)εωjωi2L(fω)i=ji=j

对表达式进行一阶近似,假定 ε → 0 + \varepsilon \rightarrow 0^+ ε0+ ,有

∂ θ j ∗ ∂ ω i = { 1 i = j 0 i ≠ j \frac{\partial \theta^*_j}{\partial \omega_i} = \left\{\begin{matrix} 1 & i= j\\ 0 & i\not = j \end{matrix}\right. ωiθj={10i=ji=j

因此,代入结果有

▽ ω L ( f θ ∗ ) ≈ ▽ θ ∗ L ( f θ ∗ ) \\ \triangledown_{\omega} \mathcal L(f_{\theta^*}) \approx \triangledown_{\theta^*} \mathcal L(f_{\theta^*}) ωL(fθ)θL(fθ)

一阶近似的MAML元知识更新式子表示为

ω ← ω − ϵ ∑ T i ∼ p ( T ) ▽ θ ∗ ( i ) L T i ( f θ ∗ ( i ) ) θ ∗ ( i ) = ω − ε ▽ ω L T i ( f ω ) \omega\leftarrow \omega -\epsilon \sum_{\mathcal T_i\sim p(\mathcal T)} \triangledown_{{\theta^*}^{(i)}} \mathcal L_{\mathcal T_i}(f_{{\theta^*}^{(i)}}) \\ {\theta^*}^{(i)} = \omega - \varepsilon \triangledown_{\omega} \mathcal L_{\mathcal{T}_i}(f_\omega) ωωϵTip(T)θ(i)LTi(fθ(i))θ(i)=ωεωLTi(fω)

MAML在不同任务上的应用

Few-Shot Supervised Learning

MAML_第1张图片

Reinforcement Learning

RL任务定义为

T i = ( q i ( x 1 ) , q i ( x t + 1 ∣ x t , a t ) , L T i , R i ) \mathcal T_i = (q_i(x_1),q_i(x_{t+1}|x_t,a_t),\mathcal L_{\mathcal T_i},R_i) Ti=(qi(x1),qi(xt+1xt,at),LTi,Ri)

其中 q q q 表示初始状态分布和状态转移分布,损失函数表示为

L T i ( f ω ) = − E x h , a h ∼ f ω , q T i [ ∑ h = 1 H R i ( x h , a h ) ] \mathcal L_{\mathcal T_i}(f_\omega) = -\mathbb E_{x_h,a_h\sim f_\omega,q_{\mathcal T_i}}[\sum_{h=1}^HR_i(x_h,a_h)] LTi(fω)=Exh,ahfω,qTi[h=1HRi(xh,ah)]

对于RL任务,通常使用Policy Gradient 方法进行梯度估计。

MAML_第2张图片

实验

作者通过实验观察到,一阶近似的性能与使用二阶导数获得的性能几乎相同,这表明MAML的大部分改进都来自目标在更新后参数值处的一阶梯度,而不是来自更新后参数值的二阶梯度。 过去的工作已经观察到ReLU神经网络在局部几乎是线性的,这表明在大多数情况下二阶导数可能接近于零,部分解释了一阶近似的良好性能。

MAML_第3张图片

MAML与Transfer Learning

MAML_第4张图片

总结

作者给出了一种基于梯度下降来学习模型初始化参数的元学习方法,方法简单,不会为元学习引入任何学习的参数。它可以与任何适合基于梯度训练的模型表示以及任何可区分的目标(包括分类,回归和强化学习)相结合。MAML的训练过程中只考虑了内部模型参数的一步更新,但是在测试时可以进行充分的finetune。这项工作是迈向一种简单通用的元学习技术的一步,该技术可应用于任何问题和模型。

你可能感兴趣的:(元学习,元学习)