Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks

MAML的核心思想是利用元学习来找到一个好的模型初始化,从而能够在新任务上进行快速适应。这种方法旨在处理“少样本学习”的挑战,即当新任务的数据量非常有限时如何有效地学习。传统学习的数据点是一个样本,而元学习的数据点是一个小数据集(任务),任务包含了很多样本。元学习对每个任务中的每个样本进行训练得到每个任务的loss,并得到任务的损失和losses。对losses进行优化来更新元学习模型的参数。

MAML:

摘要:提出一个模型无关的元学习算法,它与任何由梯度下降训练的模型兼容并且可以应用到各种不同的学习问题,包括分类,回归,强化学习。元学习的目标是在各种学习任务上训练一个模型,它可以仅仅使用小数量的训练样本来解决新的学习任务。在我们的方法中,模型的参数被明确地训练,这样少量的梯度步长和来自新任务的少量训练数据将在该任务上产生良好的泛化性能。该方法训练模型更容易去微调。在两个小样本的图像分类上得到了sota的性能,在小样本回归上也得好的结果,并且加速了使用神经网络策略的策略梯度强化的微调。

引言:

问题:

从小样本得到认知目标或者快速的学习新技能属于人类擅长的事。而智能机器学习这方面的能力存在挑战。因为代理必须将其以前的经验与少量的新信息集成起来,同时避免对新数据进行过拟合(只学会了这几个样本,并没有学习到能力)。此外,先前的经验和新数据的形式将取决于任务本身。

重要性:

因此,提出的方法应该对任务和实现任务的方法通用。

难点:

创思:

在这项工作中,提出了一个元学习算法MAML,与特定模型无关,即它可以直接应用于任何可微的模型。MAML聚焦在深度神经网络,阐释了如何用一个最小步数的微调,便可以更容易处理不同的网络结构和不同的问题,包括分类,回归,策略梯度强化学习。

提出的方法关注学习模型的初始化参数。以便新任务再模型上通过少量的样本和迭代可以进行快速适应。与先验元学习方法和学习更新函数或者更新规则不同,算法没有扩展到学习参数或模型结构的数量上(有论文已经做了结构和数量的了)。MAML可以组合全连接,卷积,RNN,不同的损失函数,包括可微分的监督损失和不可微分的强化学习目标。

模型参数的训练过程,通过几个或者一个梯度更新步骤,简单的微调参数可以得到好的结果。事实上,模型的优化是容易且快速的,允许在正确的空间快速学习。学习的过程可以被看作最大化新任务损失函数对参数的敏感性。当敏感性高的时候,对于参数的小的局部的改变可以导致在任务损失上的提升。

结果:

评估MAML相比流行的SOTA的专门为监督分类设计的one-shot 学习方法。方法使用小的参数,但也可以容易的应用到回归以及强化学习,归功于直接预训练初始参数使得性能提升。

假设:

模型:

MAML:随机初始化模型参数,通过训练来学习最优的初始化参数。初始化参数的训练主要分为两步,第一步是任务内的参数更新,第二步是任务间的参数更新

MAML 框架

其中:

Require 给出所有任务的分布以及参数更新的学习率

1 、随机初始化模型参数;2、 循环训练更新参数,直到训练截止;3 、采样一个batch,包含多个任务,每个任务K个样本;4、遍历所有任务;5、计算第i个任务在lossL下的梯度;6 、任务内的参数更新;7、batch中的任务内参数更新完成;8、任务间的参数更新


不同的任务需要选择不同的loss,在回归和分类的算法上的应用时,loss的选择为均方误差和交叉熵;在算法1中具体化任务和问题得到算法2:

算法2,在监督回归和分类的算法上的应用的算法

在强化学习上的MAML,loss为奖励函数,模型输出为决策,

MAML for Reinforcement Learning

实验:

实验回答论文2个问题(这种先描述问题的方法可以借鉴到写作上):

1)MAML可以在新任务上快速的学习吗?  

2)模型用MAML,在额外的更新次数和样本个数上可以连续的提升性能?

回归任务,用样本做sin函数回归

回归任务

pretrained的方法只做一次参数更新,而MAML做两次参数更新,第一次更新为下一次更新确定方向。不同的梯度次数训练得到的预测结果不同,从图中可以看到K=5和K=10时10次更新结果最好,1次梯度下降有不错的效果,能够得到快速的适应,回答了任务1。随着更新次数(grad step)和样本个数K的提高,性能得到了提升,回答了问题2。预训练的方法没有元参数更新的步骤,效果都很差,很难拟合。

回归MSE

通过loss值可以看出MAML在步数增加的情况没有过拟合,loss更低,性能持续提高,回答了问题2。

分类实验:

Datasets:Omniglot,MiniImagenet

Omniglot:来自50个不同的字母(类),1623个样本,选择20个类。1200个作为训练集,剩下的做测试集。

MiniImagenet:64个训练类,12个验证类,24个测试类

分类实验结果

baseline:

MANN:Memory-Augmented Neural Networks 记忆增强的神经网络

Siamese nets 孪生网络,共享encoder权重

matching nets 匹配网络,few-shot learning方法,用目标样本和支持集一起做嵌入,后计算二者的相似度作为权重,为支持集赋予权重预测标签。

neural statistician 神经统计师模型,包括encoder,统计网络(有很多不同的统计方式),decoder。统计网络的任务是将所有样本的特征整合,输出一个集合表示,即统计信息【加一些额外的设计和策略,神经统计师是否可以被扩展并应用于演化聚类任务?】

memory mod. 记忆增强的神经网络的一种,原文提到运用到life-long中受限。

meta-learner LSTM 在元学习场景中使用的LSTM,LSTM接受梯度信息,输出应该应用于模型权重的更新。LSTM被看作一个优化器。

MAML first order approx 代表的是梯度之考虑一次微分,二次微分因为会带来计算开销被忽略。


分类code:

maml pytorch代码:https://github.com/dragen1860/MAML-Pytorch/blob/master/meta.py

代码里的实现,对每个任务,先初始化参数,对初始化的模型参数进行训练得到第一次参数,在第一次参数的更新方向上更新了初始参数。也就是第一次参数的更新决定了更新方向,第二次更新更新了实际参数。

对batch,batch中每个任务学习对应的任务loss,将每个loss求和得到整体losses,并对losses进行优化。

微调过程:copy训练好的模型,在模型上进行微调和验证。在测试集学习每个任务的loss,并得到losses和更新权重。分别对任务中的样本在新权重下进行测试。

强化学习(实验部分很难看懂,以后补充)

       讨论和未来工作:介绍了一种基于元学习的方法,该方法基于通过梯度下降学习易于适应的模型参数。方法有很多好处,它很简单,并且没有为元学习引入任何学习参数。它可以组合任何可以用基于梯度训练的模型,任何可以微分的目标,包括分类,回归,强化学习。模型仅仅产生权重的初始化,适应任何数据数K和梯度步骤数setp grad,通过SOTA的分类结果,也在RL上使用了策略梯度。从过去的任务中重用知识可能是制作高容量可扩展模型(例如深度神经网络)的关键因素,可以使用小数据集进行快速训练。这项工作是迈向简单通用元学习技术的第一步,可应用于任何问题和任何模型。该领域的进一步研究可以使多任务初始化成为深度学习和强化学习的标准成分。非常有用的工作!

你可能感兴趣的:(Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks)