Still being updated……
Reference blog: https://www.mutepad.com/2019/05/24/maml-%E8%AE%BA%E6%96%87%E5%8F%8A%E4%BB%A3%E7%A0%81%E9%98%85%E8%AF%BB%E7%AC%94%E8%AE%B0/
Code: https://github.com/dragen1860/MAML-Pytorch
Paper: https://arxiv.org/abs/1703.03400
First, let's look at the network structure:
As the config below shows, the backbone is a standard four-layer convolutional neural network: each conv layer has 32 filters of size 3×3, followed by ReLU, batch norm, and max pooling; the convolutional output is finally flattened and fed into a linear classification layer.
config = [
    # ('conv2d', [ch_out, ch_in, kernel, kernel, stride, padding])
    ('conv2d', [32, 3, 3, 3, 1, 0]),
    ('relu', [True]),
    ('bn', [32]),
    # ('max_pool2d', [kernel, stride, padding])
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]),
    ('relu', [True]),
    ('bn', [32]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]),
    ('relu', [True]),
    ('bn', [32]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]),
    ('relu', [True]),
    ('bn', [32]),
    ('max_pool2d', [2, 1, 0]),
    ('flatten', []),
    # ('linear', [out_features, in_features]); 5x5 is the final spatial size for 84x84 inputs
    ('linear', [args.n_way, 32 * 5 * 5])
]
maml = Meta(args, config).to(device)
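As a sanity check on the 32 * 5 * 5 input size of the final linear layer, here is a minimal sketch that traces the spatial side length of an input through the conv/pool layers of the config above. It assumes the default 84×84 MiniImagenet resize, and the helper function is illustrative only, not part of the repo:

# Illustrative helper (not from the repo): trace the spatial side length of an
# 84x84 input through the conv / max-pool layers of the config above.
def trace_spatial_size(config, size=84):
    for name, param in config:
        if name == 'conv2d':
            k, stride, pad = param[2], param[4], param[5]
            size = (size + 2 * pad - k) // stride + 1
        elif name == 'max_pool2d':
            k, stride, pad = param
            size = (size + 2 * pad - k) // stride + 1
    return size

# 84 -> 82 -> 41 -> 39 -> 19 -> 17 -> 8 -> 6 -> 5, so the flattened feature map
# is 32 channels x 5 x 5 = 800 values, matching the linear layer's input size.
print(trace_spatial_size(config))  # 5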
maml is an instance of Meta, the class the repo author wraps around the whole algorithm; it inherits from torch.nn.Module. Its core code is shown below, and the forward method maps quite directly onto the computation described in the paper. Since this is an image-classification problem, cross-entropy is used as the loss. Following the paper, the inner loop takes each task's support set (x_spt), computes the loss, takes the gradient, and applies a gradient step to obtain the adapted parameters θ′ (the fast weights); in this implementation the inner step is repeated self.update_step times rather than just once:
for k in range(1, self.update_step):
    # 1. run the i-th task and compute loss for k=1~K-1
    logits = self.net(x_spt[i], fast_weights, bn_training=True)
    loss = F.cross_entropy(logits, y_spt[i])
    # 2. compute grad on theta_pi
    grad = torch.autograd.grad(loss, fast_weights)
    # 3. theta_pi = theta_pi - train_lr * grad
    fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

    logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
    # loss_q will be overwritten; only the loss_q of the last update step is kept
    loss_q = F.cross_entropy(logits_q, y_qry[i])
    losses_q[k + 1] += loss_q

    with torch.no_grad():
        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
        correct = torch.eq(pred_q, y_qry[i]).sum().item()  # .item() gives a plain Python int
        corrects[k + 1] = corrects[k + 1] + correct
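The fast-weight step above is exactly the per-task adaptation from the MAML paper. Writing α for self.update_lr, a single inner step is

\theta_i' = \theta - \alpha \, \nabla_{\theta} \mathcal{L}_{\mathcal{T}_i}(f_{\theta})

and each further step applies the same rule starting from the previous fast weights.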
Once the inner updates are done, each task's loss on the query set (x_qry) has been accumulated into losses_q; the second, meta-level gradient update of the parameters is then performed in the following code:
# end of all tasks
# average the last-step query-set loss over all tasks
loss_q = losses_q[-1] / task_num

# optimize theta parameters
self.meta_optim.zero_grad()
loss_q.backward()
self.meta_optim.step()
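This is the meta-update from the paper: writing β for the meta learning rate (the step here is taken by meta_optim), the original parameters are updated against the query losses of the adapted models,

\theta \leftarrow \theta - \beta \, \nabla_{\theta} \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})

where the sum over tasks corresponds to losses_q[-1] / task_num in the code.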
During training, the dataset is split into two parts: mini and mini_test correspond to the meta-train set and the meta-test set. The author wraps a MiniImagenet class that inherits from PyTorch's Dataset and overrides __getitem__(). Iterating over it yields four values per task: the support-set samples, the support-set labels, the query-set samples, and the query-set labels.
mini = MiniImagenet('/home/i/tmp/MAML-Pytorch/miniimagenet/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
k_query=args.k_qry,
batchsz=10000, resize=args.imgsz)
mini_test = MiniImagenet('/home/i/tmp/MAML-Pytorch/miniimagenet/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
k_query=args.k_qry,
batchsz=100, resize=args.imgsz)
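To make those four returned values concrete, here is a small hedged check of the tensor shapes one meta-train batch should have; the shapes in the comments are my reading of the repo (assuming imgsz=84), and db_peek is just an illustrative name:

# Hedged sketch: peek at one meta-train batch and inspect the shapes.
from torch.utils.data import DataLoader

db_peek = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)
x_spt, y_spt, x_qry, y_qry = next(iter(db_peek))

print(x_spt.shape)  # expected: [task_num, n_way * k_spt, 3, imgsz, imgsz]
print(y_spt.shape)  # expected: [task_num, n_way * k_spt]
print(x_qry.shape)  # expected: [task_num, n_way * k_qry, 3, imgsz, imgsz]
print(y_qry.shape)  # expected: [task_num, n_way * k_qry]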
Next is the overall training flow. At every step, the core MAML algorithm is run on a batch of meta-train tasks; every 500 steps, the current parameters are evaluated on the meta-test set, fine-tuning on each test task before measuring accuracy:
for epoch in range(args.epoch // 10000):
    # sample a batch of tasks from the meta-train set
    db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

    for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)
        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 30 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:  # evaluation on the meta-test set
            db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
            accs_all_test = []

            for x_spt, y_spt, x_qry, y_qry in db_test:
                # batch size is 1, so squeeze out the task dimension
                x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                             x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                accs_all_test.append(accs)

            accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
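maml.finetunning itself is not shown here. Roughly, as I read the repo, it adapts a deep copy of the meta-learned network on the test task's support set and reports query-set accuracy after each adaptation step, so the meta parameters are never modified. A paraphrased sketch, not the repo's exact code (attribute names such as update_step_test and update_lr are what I recall from the repo):

from copy import deepcopy

import torch
import torch.nn.functional as F

def finetune_sketch(meta, x_spt, y_spt, x_qry, y_qry):
    net = deepcopy(meta.net)                      # keep the meta-learned parameters untouched
    fast_weights = list(net.parameters())
    accs = []
    for _ in range(meta.update_step_test):        # assumed attribute, as in my reading of the repo
        # adapt on the support set
        loss = F.cross_entropy(net(x_spt, fast_weights, bn_training=True), y_spt)
        grad = torch.autograd.grad(loss, fast_weights)
        fast_weights = [w - meta.update_lr * g for g, w in zip(grad, fast_weights)]
        # measure accuracy on the query set after this adaptation step
        with torch.no_grad():
            pred = net(x_qry, fast_weights, bn_training=True).argmax(dim=1)
            accs.append((pred == y_qry).float().mean().item())
    return accs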
Finally, a screenshot of an actual run. Under the 5-way 1-shot setting the accuracy is not very high and still falls somewhat short of the results reported in the paper; most likely this is a matter of hyperparameter tuning: