Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
用于深度网络快速自适应的模型不可知元学习
元学习的目标是在各种学习任务上训练一个模型,学习一个模型初始化,这样它可以只使用少量的训练样本来解决新的学习任务。也就是说训练一个具体对各种任务都有极强泛化性的模型,在新任务中只需要小样本训练对参数进行微调即可。文章中,该方法可以用到分类、回归和强化学习的方法中。
快速学习要求对大量任务训练学习先验知识,再将其与新任务的数据相结合,并防止在新任务中过拟合。
在元学习中,目标是从少量新数据中快速学习新任务,元学习器训练模型以学习大量不同的任务。其关键思想是训练模型的初始参数,以便在通过一个或多个梯度步长更新参数后,模型在新任务上具有最大性能,该梯度步长是用来自该新任务的少量数据计算的。
元学习的新任务快速学习可以视为构建广泛适用于许多任务的内部表示。如果内部表示适用于许多任务,那么简单地稍微微调参数可以产生良好的结果。
并且元学习与模型的类型无关。
元学习的目标是训练能够实现快速适应的模型,这一问题设置通常被正式化为少量学习。
元学习的目标是训练一个模型,该模型仅使用几个数据点和训练迭代就能快速适应新任务。方法是在元学习阶段对模型或学习者进行一组任务的训练。
一个任务的定义形式如下:
L为损失函数, q ( x 1 ) q(x1) q(x1)为初始观测的概率, q ( x t + 1 ∣ x t , a t ) q(x_{t+1}|x_{t}, a_{t}) q(xt+1∣xt,at)为状态转移概率,H为一个episode的长度。(任务定义更针对强化学习)对分类和回归任务H一般为1。
先前的工作试图训练摄取整个数据集的递归神经网络或可在测试时与非参数方法结合的特征嵌入。
元学习思想是,一些内部表征比其他表征更容易传递。例如,神经网络可能学习广泛适用于p(T)中所有任务的内部特征,而不是单个任务。我们如何鼓励这种通用表示的出现?我们的目标是找到对任务变化敏感的模型参数,这样,当沿着损失梯度的方向改变时,参数的微小变化将对从p(T)得出的任何任务的损失函数产生很大的改善。
元学习的算法流程如下:
其中2-8为外部循环,4-7为内部循环。
2:开始循环
3:首先这里会采样多个任务
4:对于各个任务进行内部循环
5:对于各个任务中采样得到的K个样本(训练集)根据损失计算参数 θ \theta θ梯度
6:使用梯度下降计算当前的自适应参数 θ ′ \theta^{'} θ′,计算公式为:
注意我们这里并没有直接使用 θ ′ \theta^{'} θ′来替换 θ \theta θ,而仅仅是计算了 θ ′ \theta^{'} θ′的值,这是为了进一步计算下一步更新的梯度。
7: 结束内循环
8:外循环最重要的一步,更新任务的参数目标 θ \theta θ,更新公式为
注意这里是使用每个任务的测试集来更新。同时注意求导过程中,这里是使用的各个任务中基于 θ ′ \theta^{'} θ′的模型对于初始参数 θ \theta θ的梯度的和。
补充:注意这个求导公式涉及到了 θ \theta θ的二阶导。如下图,由于首先需要对 θ ′ \theta^{'} θ′进行求导,进一步对 θ ′ \theta^{'} θ′求 θ \theta θ的倒数,推导如下:
但是在实现过程中MAML对这个二阶导的计算做了近似,因为不近似的话二阶导要保存计算图,存储空降和计算速度都会受到影响,会花费大量的计算时间。这里近似把二阶导数置为0。
因此在实际代码中 f ( θ ′ ) f(\theta^{'}) f(θ′)对 θ \theta θ求导等价于 f ( θ ′ ) f(\theta^{'}) f(θ′)对 θ ′ \theta^{'} θ′求导。如下是计算时的关键代码。
for i in range(task_num):
# 1. run the i-th task and compute loss for k=0
logits = self.net(x_spt[i], vars=None, bn_training=True)
loss = F.cross_entropy(logits, y_spt[i])
#可以看到下面的这个grad的计算图没有保存
grad = torch.autograd.grad(loss, self.net.parameters())
fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
for k in range(1, self.update_step): #第二步更新了
logits = self.net(x_spt[i], fast_weights, bn_training=True)
loss = F.cross_entropy(logits, y_spt[i])
grad = torch.autograd.grad(loss, fast_weights)
#这里就不用对net的参数求导,近似为对fastw求导
fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))
logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
# loss_q will be overwritten and just keep the loss_q on last update step.
loss_q = F.cross_entropy(logits_q, y_qry[i])
losses_q[k + 1] += loss_q
self.meta_optim.zero_grad()
loss_q.backward()
self.meta_optim.step()
#这里的loss是对net的参数求导,虽然里面有fastw,但由于没有保存计算图,所以其对net的导数为1
代码来自:https://blog.csdn.net/Cecilia6277/article/details/109091482
因此,个人认为元学习这篇文章主要的几点如下: