(前几天忙着处理联邦学习和终身学习任务,加上有点犯懒,没有坚持看论文,今天继续!!)
第一部分点击这里!!
Learning to Learning with Gradients———论文阅读第一部分
前三章我们主要探讨了元学习的基本概念,以及如何以数学方式去描述任何一个元学习算法,以及元学习应该具备的性质等,这一节,论文想提出一个通用的、与模型无关的元学习算法。作者主要关注的是如何去训练出事参数,使得模型在新任务重使用少量数据计算就能达到最大的性能。这也就是MAML!!(很重要的元学习算法)
作者首先提到了,在神经网络中,可能可以学习到适用于所有任务分布的内部特征(也就是任务分布中的关键信息都能get到)而不是只针对于一个任务。换句话说,作者目的是让模型在新任务上使用基于梯度的学习规则进行微调(这里可以理解为,当我们把模型直接拿来测试效果很差,但我只需要从测试集上很少抽一部分进行几步的训练,就能得到很好的结果,也就是fine-tune)。也就是找到对任务变化很敏感的参数,当进行梯度计算时(基于损失的梯度方向),这样微调就能带来很大的变化。如下图
θ \theta θ是我们的元学习参数, ϕ \phi ϕ是适应于各个任务上的参数(这里和原本的MAML是反着的,原本是 ϕ \phi ϕ才是元学习参数,读者自行转换一下)。联系上文,就是说各个任务的计算出来的梯度都各不相同,这时沿着各个方向的梯度进行调整就会得到很大的改变。接下来的问题是如何基于这种思想去传递每一个任务的梯度。下面我们以数学公式表示:
对于每一个任务,我们从原本元学习参数 θ \theta θ进行梯度下降后可以得到针对此任务最敏感的参数:
我们的目标是结合所有任务的梯度信息,也就是优化 θ \theta θ在样本任务性能。数学公式表达如下:
也就是对于所有任务来说,我们从元学习参数变为各个任务的参数后,让整一个loss达到最小。算法如下;
由于这个算法细节很简单同时又非常能体现meta-learning的思想,而且实验效果和利用率都很高,所以我简单来讲一下这个算法,我会对照着代码进行讲解。
首先是随机初始化参数 θ \theta θ,这个没什么好说的,构建模型你自己初始化一个,然后接下来是从我们的任务分布依次选一些task构成support size。这里要提到omniglot数据集,这个数据集包括1000多个类,每个类只有10几张图片。我们可以根据这个选择多个任务,例如5way10shot就是每一个任务有5类,每一类有10张图片,所以每一个任务就有50个图片,我们选取多个任务组成,就有n*50个数据啦。
算法7是一次梯度后,去看一下准确率损失咋样,代码如下:
# 算出预测值
y_hat = self.net(x_spt[i], params=None, bn_training=True) # (ways * shots, ways)
loss = F.cross_entropy(y_hat, y_spt[i])
grad = torch.autograd.grad(loss, self.net.parameters())
## 将梯度和参数\theta一一对应起来
tuples = zip(grad, self.net.parameters())
# fast_weights这一步相当于求了一个\theta - \alpha*\nabla(L)
#这里采用这种形式而不是loss.backward()是因为以前的参数最后还需要,所以先存一下新的参数
fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))
# 在query集上测试,计算准确率
# 这一步使用更新前的数据
with torch.no_grad():
y_hat = self.net(x_qry[i], self.net.parameters(), bn_training=True)
loss_qry = F.cross_entropy(y_hat, y_qry[i])
loss_list_qry[0] += loss_qry
pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1) # size = (75)
correct = torch.eq(pred_qry, y_qry[i]).sum().item()
correct_list[0] += correct
# 使用更新后的数据在query集上测试。
with torch.no_grad():
y_hat = self.net(x_qry[i], fast_weights, bn_training=True)
loss_qry = F.cross_entropy(y_hat, y_qry[i])
loss_list_qry[1] += loss_qry
pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1) # size = (75)
correct = torch.eq(pred_qry, y_qry[i]).sum().item()
correct_list[1] += correct
到第8步,我们根据 θ \theta θ求出 ϕ \phi ϕ,就使用刚刚我们上面用的fast_weight进行更新
for k in range(1, self.update_step):
#
y_hat = self.net(x_spt[i], params=fast_weights, bn_training=True)
loss = F.cross_entropy(y_hat, y_spt[i])
grad = torch.autograd.grad(loss, fast_weights)
tuples = zip(grad, fast_weights)
fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))
if k < self.update_step - 1:
with torch.no_grad():
y_hat = self.net(x_qry[i], params=fast_weights, bn_training=True)
loss_qry = F.cross_entropy(y_hat, y_qry[i])
loss_list_qry[k + 1] += loss_qry
else:
y_hat = self.net(x_qry[i], params=fast_weights, bn_training=True)
loss_qry = F.cross_entropy(y_hat, y_qry[i])
loss_list_qry[k + 1] += loss_qry
with torch.no_grad():
pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
correct = torch.eq(pred_qry, y_qry[i]).sum().item()
correct_list[k + 1] += correct
# print('hello')
首先是根据support set对fastweight进行变换更新,更新后在query set上求出损失,累积梯度(这里update_step可以自行调节,而每次这里只记录最后一次的损失梯度避免梯度爆炸),这里尤其注意是在query上进行损失计算和梯度累积。
最后每一个有梯度的损失都记录在loss_qry[-1]上了,我们对他loss.backward()更新即可。
loss_qry = loss_list_qry[-1] / task_num
self.meta_optim.zero_grad() # 梯度清零
loss_qry.backward()
self.meta_optim.step()
我画一个图来表示一下这个更新过程
(中途中也求了在update中在query set的损失,但其实只是展示他更新速度有多快,并没有计入梯度,不影响学习过程)
元学习中一个很重要的算法MAML,给出了讲解以及对应的代码,现在继续深入这个算法。
核心和之前是一样的,对于监督学习的话,单输入单输出,我们的损失函数可以定位为:
可以发现,之前的MAML算法我们用到了二阶导数,然而二阶导数会增加我们的计算量,因此作者就想用一阶近似去模拟二阶,看是否能代替。优化定义如下:
sg表示来停止梯度的操作,这种近似是将参数更新视为一个常数( θ t o θ + c \theta\ to\ \theta+c θ to θ+c),然后反向去传播这个新的性能任务