考虑用 f f f来表示一个模型,该模型能将observation x x x 映射到输出 a a a上。
在元学习中,每一个完整的任务被看成一个训练样本,训练所得模型的作用是要使其能够适用于其他新的任务。因此采用一个通用的表示方法来代表每个任务。
T = { L ( x 1 , a 1 , . . . , x H , a H ) , q ( x 1 ) , q ( x t + 1 ∣ x t , a t ) , H } T=\left\{L(x_1,a_1,...,x_H,a_H),q(x_1),q(x_{t+1}|x_t,a_t),H \right\} T={L(x1,a1,...,xH,aH),q(x1),q(xt+1∣xt,at),H}
L L L表示损失函数, q ( x 1 ) q(x_1) q(x1)是初始观察量的分布, q ( x t + 1 ∣ x t , a t ) q(x_{t+1}|x_t,a_t) q(xt+1∣xt,at)为过渡分布(transition distribution), H H H是片段长度(episode length)。在独立同分布问题中, H = 1 H=1 H=1。
考虑训练一个模型,可适用于服从分布 p ( T ) p(T) p(T)的任务集合。以K-shot learning为例,每个训练任务中仅有K个样本,从 p ( T ) p(T) p(T)中sample一个新任务 T i T_i Ti,模型在新任务的测试集上所得的损失 L T i L_{T_i} LTi即为元学习的损失。
本文通过元学习提出一种能够学习任何标准模型参数的方法,使模型能够快速适应新任务。其方法背后的intuition:网络中的某些中间特征表示比其他的更具有迁移性,可以广泛应用到服从分布 p ( T ) p(T) p(T)的所有任务上, 而不仅仅是一个单一的任务。
首先,将模型表示成一个参数化函数 f θ f_{\theta} fθ, θ \theta θ是模型的参数,当模型应用到一个新任务 T i T_i Ti上时,采用一步或多步梯度下降来更新模型参数:(以一次梯度更新为例)
通过优化所有 f θ i ′ f_{\theta_i'} fθi′在其相应的任务 T i T_i Ti上的表现,来训练模型的初始参数 θ \theta θ,meta-objective表示如下:
同样采用梯度下降法来更新 θ \theta θ:
算法流程如下:
- 针对待解决的任务选择模型,并初始化模型参数 θ \theta θ
- 从分布 p ( T ) p(T) p(T)中采样一组训练任务 T i T_i Ti,对所有的任务进行以下步骤:
a. 计算任务 T i T_i Ti中 K K K个样本上的损失 ▽ θ L T i ( f θ ) \triangledown_\theta L_{T_i}(f_\theta) ▽θLTi(fθ)
b. 采用梯度下降算法更新模型参数
θ i ′ = θ − α ▽ θ L T i ( f θ ) \theta'_i =\theta -\alpha\triangledown_\theta L_{T_i}(f_\theta) θi′=θ−α▽θLTi(fθ)
最终目标是要找到一个最佳的模型初始参数 θ \theta θ,使得网络只需要进行少数更新就能在所有任务上都能达到最佳的效果。即找到一个 θ \theta θ使 ∑ T i ∼ p ( T ) L T i ( f θ i ′ ) \sum_{T_i\sim p(T)}L_{T_i}(f_{\theta'_i}) ∑Ti∼p(T)LTi(fθi′)最小。- 采用梯度下降算法优化初始参数 θ \theta θ:
θ ← θ − β ▽ θ ∑ T i ∼ p ( T ) L T i ( f θ i ′ ) \theta\leftarrow\theta-\beta \triangledown_\theta\sum_{T_i\sim p(T)}L_{T_i}(f_{\theta'_i}) θ←θ−β▽θ∑Ti∼p(T)LTi(fθi′)
至此,得到最终的最优初始参数 θ \theta θ。
实验待验证的问题有三点:
待解决任务:通过一系列数据点来拟合一条正弦曲线,即给定 { ( x i , y i ) } i = 1 , . . . , K {\left\{(x_i,y_i) \right\}}_{i=1,...,K} {(xi,yi)}i=1,...,K,来预测正弦函数的幅值 A A A和相角 ϕ \phi ϕ,其中 A ∈ [ 0.1 , 5.0 ] , ϕ ∈ [ 0 , π ] , x i ∈ [ − 5.0 , 5.0 ] A\in[0.1,5.0],\phi\in[0,\pi],x_i\in[-5.0,5.0] A∈[0.1,5.0],ϕ∈[0,π],xi∈[−5.0,5.0]。
上图实验结果表明:
在数据集Omniglot和MiniImagenet上进行N-way K-shot的实验(K=1 or 5),实验结果如下:
本文提出一种利用梯度下降来学习具备easily adaptable的模型参数的方法,其优势如下:
1)流程简单,且不引入额外需要学习的参数;
2)可以适用于任意能够采用梯度下降来训练的模型;
3)由于本方法仅产生一组初始权重,因此adaptation的过程可以通过任意数量的数据、任意次数的梯度更新来实现。