元学习深度解析

引言

 元学习( m e t a \mathrm{meta} meta- l e a r n i n g \mathrm{learning} learning)是过去几年最火爆的学习方法之一,各式各样的 p a p e r \mathrm{paper} paper都是基于元学习展开的。深度学习模型训练模型特别吃计算硬件,尤其是人为调超参数时候,更需要大量的计算。另一个头疼的问题是在某个任务下大量数据训练的模型,切换到另一个任务后,模型就需要重新训练,这样非常耗时耗力。工业界财大气粗有大量的GPU可以承担起这样的计算成本,但是学术界因为经费有限经不起这样的消耗。元学习可以有效的缓解大量调参和任务切换模型重新训练带来的计算成本问题。

元学习介绍

元学习希望使得模型获取一种学会学习调参的能力,使其可以在获取已有知识的基础上快速学习新的任务。机器学习是先人为调参,之后直接训练特定任务下深度模型。元学习则是先通过其它的任务训练出一个较好的超参数,然后再对特定任务进行训练。
元学习深度解析_第1张图片

 在机器学习中,训练单位是样本数据,通过数据来对模型进行优化;数据可以分为训练集、测试集和验证集。在元学习中,训练单位是任务,一般有两个任务分别是训练任务( T r a i n   T a s k s \mathrm{Train\text{ }Tasks} Train Tasks)亦称跨任务( A c r o s s   T a s k s \mathrm{Across\text{ }Tasks} Across Tasks)和测试任务( T e s t   T a s k \mathrm{Test\text{ }Task} Test Task)亦称单任务( W i t h i n   T a s k \mathrm{Within\text{ }Task} Within Task)。训练任务要准备许多子任务来进行学习,目的是学习出一个较好的超参数,测试任务是利用训练任务学习出的超参数对特定任务进行训练。训练任务中的每个任务的数据分为 S u p p o r t   s e t \mathrm{Support\text{ }set} Support set Q u e r y   s e t \mathrm{Query\text{ }set} Query set T e s t   T a s k \mathrm{Test\text{ }Task} Test Task中数据分为训练集和测试集。

 令 φ \varphi φ表示需要设置的超参数, θ \theta θ表示神经网络待训练的参数。元学习的目的就是让函数 F φ , θ F_{\varphi,\theta} Fφ,θ在训练任务中自动训练出最好的 φ ∗ \varphi^{*} φ,再利用 φ ∗ \varphi^{*} φ这个先验知识在测试任务中训练出特定任务下模型 f θ f_\theta fθ中的参数 θ \theta θ,如下所示的依赖关系:
F φ , θ ↦ T r a i n   T a s k s ( F φ ∗ , θ ⇔ f θ ) ↦ T e s t   T a s k ( F φ ∗ , θ ∗ ⇔ f θ ∗ ) F_{\varphi,\theta} \xmapsto{\mathrm{Train\text{ }Tasks}}(F_{\varphi^{*},\theta}\Leftrightarrow{} f_{\theta}) \xmapsto{\mathrm{Test\text{ }Task}}(F_{\varphi^{*},\theta^*}\Leftrightarrow{}f_{\theta^*}) Fφ,θTrain Tasks (Fφ,θfθ)Test Task (Fφ,θfθ)当训练一个神经网络的时候,具体一般步骤有,预处理数据集 D D D,选择网络结构 N N N,设置超参数 γ \gamma γ,初始化参数 θ 0 \theta_0 θ0,选择优化器 O O O,定义损失函数 L L L,梯度下降更新参数 θ \theta θ。具体步骤如下图所示
元学习深度解析_第2张图片
元学习会去学习所有需要由人去设置和定义的参数变量 φ \varphi φ。在这里参数变量 φ \varphi φ属于集合为 Φ \Phi Φ,则有 φ ∈ Φ = { D , N , γ , θ 0 , O , L } \varphi\in \Phi=\{D,N,\gamma,\theta_0,O,L\} φΦ={ D,N,γ,θ0,O,L}不同的元学习,就要去学集合 Φ \Phi Φ中不同的元素,相应的就会有不同的研究领域。

  • 学习预处理数据集 D D D:对数据进行预处理的时候,数据增强会增加模型的鲁棒性,一般的数据增强方式比较死板,只是对图像进行旋转,颜色变换,伸缩变换等。元学习可以自动地,多样化地为数据进行增强,相关的代表作为 D A D A \mathrm{DADA} DADA
    论文名称:DADA: Differentiable Automatic Data Augmentation
    论文链接:https://arxiv.org/pdf/2003.03780v1.pdf
    论文详情:ECCV, 2020

  • 学习初始化参数 θ 0 \theta_0 θ0:权重参数初始化的好坏可以影响模型最后的分类性能,元学习可以通过学出一个较好的权重初始化参数有助于模型在新的任务上进行学习。元学习学习初始化参数的代表作是 M A M L \mathrm{MAML} MAML( M o d e l \mathrm{Model} Model- A g n o s t i c \mathrm{Agnostic} Agnostic- M e t a \mathrm{Meta} Meta- L e a r n i n g \mathrm{Learning} Learning)。它专注于提升模型整体的学习能力,而不是解决某个具体问题的能力,训练时,不停地在不同的任务上切换,从而达到初始化网络参数的目的,最终得到的模型,面对新的任务时可以学习得更快。
    论文名称:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
    论文链接:https://arxiv.org/pdf/1703.03400.pdf
    论文详情:ICML, 2017

  • 学习网络结构 N N N:神经网络的结构设定是一个很头疼的问题,网络的深度是多少,每一层的宽度是多少,每一层的卷积核有多少个,每个卷积核的大小又该怎么定,需不需要 d r o p o u t \mathrm{dropout} dropout等等问题,到目前为止没有一个定论或定理能够清晰准确地回答出以上问题,所以神经网络结构搜索 N A S \mathrm{NAS} NAS运营而生。归根结底,神经网络结构搜索其实是元学习地一个子类领域。值得注意的是,网络结构的探索不能通过梯度下降法来获得,这是一个不可导问题,一般情况下会采用强化学习或进化算法来解决。
    论文名称:Neural Architecture Search with Reinforcement Learning
    论文链接:https://arxiv.org/abs/1611.01578
    论文详情:ICLR, 2017

  • 学习选择优化器 O O O:神经网络训练的过程中很重要的一环就是优化器的选取,不同的优化器会对优化参数时对梯度的走向有很重要的影响。熟知的优化器有 A d a m \mathrm{Adam} Adam R M s p r o p \mathrm{RMsprop} RMsprop S G D \mathrm{SGD} SGD N A G \mathrm{NAG} NAG等,元学习可以帮我们在训练特定任务前选择一个好的的优化器,其代表作有
    论文名称:Learning to learn by gradient descent by gradient descent
    论文链接:https://arxiv.org/pdf/1606.04474.pdf
    论文详情:NIPS, 2016

元学习训练

元学习分为两个阶段,阶段一是训练任务训练;阶段二为测试任务训练。对应于一些论文的算法流程图,训练任务是在 o u t e r   l o o p \mathrm{outer \text{ } loop } outer loop里,测试任务任务是在 i n n e r   l o o p \mathrm{inner \text{ } loop } inner loop里。

阶段一:训练任务训练

 在训练任务中给定 h h h个子训练任务,每个子训练任务的数据集分为 S u p p o r t   s e t \mathrm{Support\text{ }set} Support set Q u e r y   s e t \mathrm{Query\text{ }set} Query set。首先通过这 h h h个子任务的 S u p p o r t   s e t \mathrm{Support\text{ }set} Support set训练 F φ , θ F_{\varphi,\theta} Fφ,θ,分别训练出针对各自子任务的模型参数 θ i ∗ ( 1 ≤ i ≤ h ) \theta_i^{*}(1\le i \le h) θi(1ih)。然后用不同子任务中的 Q u e r y   s e t \mathrm{Query\text{ }set} Query set分别去测试 f θ i ∗ f_{\theta_i^{*}} fθi的性能,并计算出预测值和真实标签的损失 l i ( 1 ≤ i ≤ h ) l_{i}(1\le i \le h) li(1ih)。接着整合这 h h h个损失函数为 L ( φ ) L(\varphi) L(φ) L ( φ ) = l 1 + ⋯ + l k + ⋯ + l h L(\varphi)=l_1+\cdots+l_k+\cdots+l_h L(φ)=l1++lk++lh最后利用梯度下降法去求出 ∂ L ( φ ) ∂ φ \frac{\partial L(\varphi)}{\partial \varphi} φL(φ)去更新参数 φ \varphi φ,从而找到最优的超参设置;如果 ∂ L ( φ ) ∂ φ \frac{\partial L(\varphi)}{\partial \varphi} φL(φ)不可求,则可以采用强化学习或者进化算法去解决。阶段一中训练任务的训练过程被整理在如下的框图中。
元学习深度解析_第3张图片

阶段二:测试任务训练

 测试任务就是正常的机器学习的过程,它将数据集划分为训练集和测试集。阶段一中训练任务的目的是找到一个好的超参设置 φ ∗ \varphi^{*} φ,利用这个先验知识可以对特定的测试任务进行更好的进行训练。阶段二中测试任务的训练过程被整理在如下的框图中。
元学习深度解析_第4张图片

你可能感兴趣的:(论文解读,深度学习,机器学习,神经网络)