MAML-Pytorch Code Reading Notes

Still being updated...
Reference blog:
https://www.mutepad.com/2019/05/24/maml-%E8%AE%BA%E6%96%87%E5%8F%8A%E4%BB%A3%E7%A0%81%E9%98%85%E8%AF%BB%E7%AC%94%E8%AE%B0/

Code: https://github.com/dragen1860/MAML-Pytorch
Paper: https://link.zhihu.com/?target=https%3A//arxiv.org/abs/1703.03400

MAML-pytorch:

First, a look at the network structure:
It is a standard four-layer convolutional network. Each conv layer uses 32 filters of size 3×3, followed by ReLU, batch norm, and max pooling; the convolutional output is finally flattened.

config = [
        ('conv2d', [32, 3, 3, 3, 1, 0]),  # [ch_out, ch_in, kernel, kernel, stride, padding]
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),        # [kernel, stride, padding]
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),        # stride 1 here, keeping a 5x5 feature map
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
]

maml = Meta(args, config).to(device)
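Each conv2d entry in the config follows [ch_out, ch_in, kernel, kernel, stride, padding] and each max_pool2d entry follows [kernel, stride, padding]. As a sanity check of the 32 * 5 * 5 flatten dimension, the spatial sizes can be traced through the layers with the standard output-size formula (a pure-Python sketch; the 84×84 input is the usual miniImagenet resize, which is an assumption here):

```python
# Trace the spatial size of an 84x84 image through the config above
# to confirm the 32 * 5 * 5 flatten dimension.

def out_size(size, kernel, stride, padding):
    """Standard conv/pool output-size formula (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

config = [
    ('conv2d', [32, 3, 3, 3, 1, 0]),   # [ch_out, ch_in, k, k, stride, pad]
    ('max_pool2d', [2, 2, 0]),         # [kernel, stride, pad]
    ('conv2d', [32, 32, 3, 3, 1, 0]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]),
    ('max_pool2d', [2, 1, 0]),         # note the stride-1 final pool
]

size, channels = 84, 3
for name, args in config:
    if name == 'conv2d':
        channels = args[0]
        size = out_size(size, args[2], args[4], args[5])
    elif name == 'max_pool2d':
        size = out_size(size, args[0], args[1], args[2])

print(channels, size)  # the flatten dim is channels * size * size
```

The trace gives 84 → 82 → 41 → 39 → 19 → 17 → 8 → 6 → 5, which is exactly why the linear layer takes 32 * 5 * 5 inputs.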

maml is a wrapper class written by the repo author; it inherits from torch.nn.Module. Its core code is shown below, and the forward method clearly mirrors the algorithm in the paper. Since this is an image-classification problem, cross-entropy is used as the loss. For each task, the loss is computed on the support set (x_spt), its gradient is taken, and the updated parameters θ' (the fast weights) are produced; the loop below repeats this inner update for self.update_step steps in total (step 0 happens just before the loop):

for k in range(1, self.update_step):
        # 1. run the i-th task and compute loss for k=1~K-1
        logits = self.net(x_spt[i], fast_weights, bn_training=True)
        loss = F.cross_entropy(logits, y_spt[i])
        # 2. compute grad on theta_pi
        grad = torch.autograd.grad(loss, fast_weights)
        # 3. theta_pi = theta_pi - train_lr * grad
        fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))
        logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
        # loss_q will be overwritten and just keep the loss_q on last update step.
        loss_q = F.cross_entropy(logits_q, y_qry[i])
        losses_q[k + 1] += loss_q
        with torch.no_grad():
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item() # tensor scalar -> Python int
                corrects[k + 1] = corrects[k + 1] + correct

After that, the loss of each task on its query set (x_qry) is computed and summed across tasks; the second-level (meta) gradient update of θ is performed in the code below:

# end of all tasks
# sum over all losses on query set across all tasks
loss_q = losses_q[-1] / task_num
 
# optimize theta parameters
self.meta_optim.zero_grad()
loss_q.backward()
self.meta_optim.step()
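The two-level update above can be reproduced end to end on a toy problem. The sketch below is a hypothetical NumPy illustration (not the repo's code, and first-order only, whereas the repo backpropagates through torch.autograd.grad): it meta-trains a scalar weight on 1-D linear-regression tasks, taking one inner step on the support set and then a meta step from the query-set gradient.

```python
import numpy as np

# Toy first-order MAML on 1-D regression y = w_true * x (hypothetical example).
rng = np.random.default_rng(0)
w = 0.0                                   # meta-parameter theta
inner_lr, meta_lr, task_num = 0.05, 0.1, 4

def loss_and_grad(w, x, y):
    # MSE loss and its gradient w.r.t. the scalar weight w
    pred = w * x
    return np.mean((pred - y) ** 2), np.mean(2 * (pred - y) * x)

for step in range(200):
    meta_grad = 0.0
    for _ in range(task_num):             # a batch of tasks
        w_true = rng.uniform(1.0, 3.0)    # each task has its own slope
        x_spt, x_qry = rng.normal(size=10), rng.normal(size=10)
        y_spt, y_qry = w_true * x_spt, w_true * x_qry

        # inner update: fast weight theta' = theta - inner_lr * grad(support)
        _, g_spt = loss_and_grad(w, x_spt, y_spt)
        w_fast = w - inner_lr * g_spt

        # query-set gradient at the adapted weight drives the meta update
        _, g_qry = loss_and_grad(w_fast, x_qry, y_qry)
        meta_grad += g_qry
    w -= meta_lr * meta_grad / task_num   # meta-optimizer step on theta

print(w)  # drifts toward the centre of the task distribution
```

The meta-parameter settles near the mean task slope: a starting point from which one inner gradient step adapts well to any sampled task, which is exactly the objective MAML optimizes.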

During training, the data is split into two parts: mini and mini_test correspond to the meta-train set and the meta-test set. The author wraps a MiniImagenet class that inherits from PyTorch's Dataset and overrides the __getitem__() method, so iterating over it yields four values: the support-set samples, the support-set labels, the query-set samples, and the query-set labels.

mini = MiniImagenet('/home/i/tmp/MAML-Pytorch/miniimagenet/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000, resize=args.imgsz)
mini_test = MiniImagenet('/home/i/tmp/MAML-Pytorch/miniimagenet/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=100, resize=args.imgsz)
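The __getitem__() contract described above can be sketched with a hypothetical minimal stand-in (not the repo's MiniImagenet class; "images" are just random vectors here), where each index yields one N-way K-shot episode as the four arrays:

```python
import numpy as np

# Hypothetical minimal episodic dataset: each __getitem__ call returns
# one N-way K-shot episode as (x_spt, y_spt, x_qry, y_qry).
class ToyEpisodes:
    def __init__(self, n_way=5, k_shot=1, k_query=15, batchsz=100, dim=8):
        self.n_way, self.k_shot, self.k_query = n_way, k_shot, k_query
        self.batchsz, self.dim = batchsz, dim
        self.rng = np.random.default_rng(0)

    def __len__(self):
        return self.batchsz                    # number of sampled episodes

    def __getitem__(self, index):
        n, ks, kq, d = self.n_way, self.k_shot, self.k_query, self.dim
        # fake "images": random vectors; labels are 0..n_way-1 per episode
        x_spt = self.rng.normal(size=(n * ks, d))
        y_spt = np.repeat(np.arange(n), ks)
        x_qry = self.rng.normal(size=(n * kq, d))
        y_qry = np.repeat(np.arange(n), kq)
        return x_spt, y_spt, x_qry, y_qry

ds = ToyEpisodes()
x_spt, y_spt, x_qry, y_qry = ds[0]
print(x_spt.shape, y_spt.shape, x_qry.shape, y_qry.shape)
```

With the defaults (5-way 1-shot, 15 queries) the support set has 5 samples and the query set 75, matching the episode sizes the real MiniImagenet class produces.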

Next comes the overall training flow. At each step, the MAML core algorithm runs on a batch of meta-train tasks; every 500 steps, the current parameters are evaluated on the meta-test set by fine-tuning on each test task.

for epoch in range(args.epoch//10000):
        # sample a batch of tasks from the training set
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)
 
        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
                x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)
 
                accs = maml(x_spt, y_spt, x_qry, y_qry)
                if step % 30 == 0:
                        print('step:', step, '\ttraining acc:', accs)
 
                if step % 500 == 0: # evaluation phase
                        db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                        accs_all_test = []

                        for x_spt, y_spt, x_qry, y_qry in db_test:
                                x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
 
                                accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                                accs_all_test.append(accs)
 
                        accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                        print('Test acc:', accs)
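In the evaluation branch, accs_all_test collects one row per test task, holding that task's accuracy after each fine-tune step; mean(axis=0) therefore averages across tasks while keeping one value per step. A small illustration with made-up numbers:

```python
import numpy as np

# Each row: one test task's accuracy at [step 0, step 1, step 2].
accs_all_test = [
    [0.20, 0.40, 0.44],
    [0.24, 0.36, 0.48],
]
# mean(axis=0): average over tasks, keeping one value per update step
accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
print(accs)  # ~[0.22, 0.38, 0.46]
```

This is why the printed test accuracy is a vector rather than a scalar: it shows how accuracy improves over successive fine-tune steps.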

Run screenshot:

Finally, a screenshot of an actual run. Under the 5-way 1-shot setting the accuracy is not very high and still falls short of the paper's reported results, most likely a hyperparameter-tuning issue:
