MAML模型无关的元学习代码完整复现(Pytorch版)

1 引言

元学习是今年来新起的一种深度学习任务,它主要是想训练出具有强学习能力的神经网络。元学习领域一开始是一个小众的领域,之前很多年都没有很好的进展,直到Finn, C.在就读博士期间发表了一篇元学习的论文,也就是大名鼎鼎的MAML,它在回归,分类,强化学习三个任务上都达到了当时最好的性能。

我曾经在半年前发表过一篇MAML的学习笔记,博文地址点这里。

MAML出现之后算是掀起来了一波研究元学习的浪潮,此后改编MAML的论文层出不穷,但都没有实质性的突破。接下来我将引用我学习笔记中的语句来简要介绍一下MAML。

MAML主要是学习出模型的初始参数,使得这个参数在新任务上经过少量的迭代更新之后就能使模型达到最好的效果。过去的方法一般是学习出一个迭代函数或者一个学习规则。MAML没有新增参数,也没有对模型提出任何约束。MAML可以看作是最大化损失函数在新任务上的灵敏度,从而当参数只有很小的改编时,损失函数也能大幅减小。

由于元学习模型天然的快速训练出好的模型,所以其主要用于小样本学习之中。元学习的论文中也大多将小样本学习任务作为论文实验。

2 数据集

本文的复现用到的数据集小样本领域的通用数据集Omniglot,数据集的地址可以在我的github中找到omniglot_standard.zip。
以下引用@心之宙对omniglot的介绍:

Omniglot 一般会被戏称为 MNIST 的转置,大家可以想想为什么?Omniglot 数据集包含来自 50个不同国家的字母表的 1623 个不同手写字符。每一个字符都是由 20个不同的人通过亚马逊的 Mechanical Turk 在线绘制的。
Omniglot 数据集总共包含 50个不同国家的字母表。我们通常将这些分成一组包含 30个字母表的背景(background)集和一组包含 20 个字母表的评估(evaluation)集。
更具挑战性的表示学习任务是使用较小的背景集 “background small 1” 和 “background small 2”。每一个都只包含 5个字母, 更类似于一个成年人在学习一般的字符时可能遇到的经验。

本文的复现主要基于omniglot的标准集。

3 代码分段详解

3.1 数据预处理

首先对zip文件进行解压,解压后可以在python子文件夹中获得如下数据集:

import torch
import numpy as np
import os
import zipfile

root_path = './../datasets'
processed_folder =  os.path.join(root_path)

zip_ref = zipfile.ZipFile(os.path.join(root_path,'omniglot_standard.zip'), 'r')
zip_ref.extractall(root_path)
zip_ref.close()
然后对图片进行预处理
# 数据预处理
root_dir = './../datasets/omniglot/python'

import torchvision.transforms as transforms
from PIL import Image

'''
an example of img_items:
( '0709_17.png',
  'Alphabet_of_the_Magi/character01',
  './../datasets/omniglot/python/images_background/Alphabet_of_the_Magi/character01')
'''
def find_classes(root_dir):
    img_items = []
    for (root, dirs, files) in os.walk(root_dir): 
        for file in files:
            if (file.endswith("png")):
                r = root.split('/')
                img_items.append((file, r[-2] + "/" + r[-1], root))
    print("== Found %d items " % len(img_items))
    return img_items

## 构建一个词典{class:idx}
def index_classes(items):
    class_idx = {}
    count = 0
    for item in items:
        if item[1] not in class_idx:
            class_idx[item[1]] = count
            count += 1
    print('== Found {} classes'.format(len(class_idx)))
    return class_idx
        

img_items =  find_classes(root_dir)
class_idx = index_classes(img_items)


temp = dict()
for imgname, classes, dirs in img_items:
    img = '{}/{}'.format(dirs, imgname)
    label = class_idx[classes]
    transform = transforms.Compose([lambda img: Image.open(img).convert('L'),
                              lambda img: img.resize((28,28)),
                              lambda img: np.reshape(img, (28,28,1)),
                              lambda img: np.transpose(img, [2,0,1]),
                              lambda img: img/255.
                              ])
    img = transform(img)
    if label in temp.keys():
        temp[label].append(img)
    else:
        temp[label] = [img]
print('begin to generate omniglot.npy')
## 移除标签信息,每个标签包含20个样本
img_list = []
for label, imgs in temp.items():
    img_list.append(np.array(imgs))
img_list = np.array(img_list).astype(np.float) # [[20 imgs],..., 1623 classes in total]
print('data shape:{}'.format(img_list.shape)) # (1623, 20, 1, 28, 28)
temp = []
np.save(os.path.join(root_dir, 'omniglot.npy'), img_list)
print('end.')

3.3 构造训练集和测试集

img_list = np.load(os.path.join(root_dir, 'omniglot.npy')) # (1623, 20, 1, 28, 28)
x_train = img_list[:1200]
x_test = img_list[1200:]
num_classes = img_list.shape[0]
datasets = {'train': x_train, 'test': x_test}

然后我们需要将构造一个批量迭代出数据的迭代器,重要部分的实现代码如下:

def next(mode='train'):
    """
    Gets next batch from the dataset with name.
    :param mode: The name of the splitting (one of "train", "val", "test")
    :return:
    """
    # update cache if indexes is larger than len(data_cache)
    if indexes[mode] >= len(datasets_cache[mode]):
        indexes[mode] = 0
        datasets_cache[mode] = load_data_cache(datasets[mode])

    next_batch = datasets_cache[mode][indexes[mode]]
    indexes[mode] += 1

    return next_batch

3.2 构造Base-Learner

由于代码较多,这里只展示重要的代码:

if params is None:
    params = self.vars

weight, bias = params[0], params[1]  # 第1个CONV层
x = F.conv2d(x, weight, bias, stride = 2, padding = 2)

weight, bias = params[2], params[3]  # 第1个BN层
running_mean, running_var = self.vars_bn[0], self.vars_bn[1]
x = F.batch_norm(x, running_mean, running_var, weight=weight,bias =bias, training= bn_training)
x = F.max_pool2d(x,kernel_size=2)  #第1个MAX_POOL层  
x = F.relu(x, inplace = [True])  #第1个relu

CONV 层-> BN 层 -> POOL层 -> ReLU层,以上四个层组成一个块,然后将四个类似的块堆叠起来,结尾接一个Flatten层和一个Linear层。

构造Meta-Learner

以下是Meta-Learner的代码,包括一个测试用的finetunning函数。

class MetaLearner(nn.Module):
    def __init__(self):
        super(MetaLearner, self).__init__()
        self.update_step = 5 ## task-level inner update steps
        self.update_step_test = 5  
        self.net = BaseNet()
        self.meta_lr = 2e-4
        self.base_lr = 4 * 1e-2
        self.inner_lr = 0.4
        self.outer_lr = 1e-2
        self.meta_optim = torch.optim.Adam(self.net.parameters(), lr = self.meta_lr)
        
    def forward(self,x_spt, y_spt, x_qry, y_qry):
        # 初始化
        task_num, ways, shots, h, w = x_spt.size()
        query_size = x_qry.size(1) # 75 = 15 * 5
        loss_list_qry = [0 for _ in range(self.update_step + 1)]
        correct_list = [0 for _ in range(self.update_step + 1)]
        
        for i in range(task_num):
            ## 第0步更新
            y_hat = self.net(x_spt[i], params = None, bn_training=True) # (ways * shots, ways)
            loss = F.cross_entropy(y_hat, y_spt[i]) 
            grad = torch.autograd.grad(loss, self.net.parameters())
            tuples = zip(grad, self.net.parameters()) ## 将梯度和参数\theta一一对应起来
            # fast_weights这一步相当于求了一个\theta - \alpha*\nabla(L)
            fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))
            # 在query集上测试,计算准确率
            # 这一步使用更新前的数据
            with torch.no_grad():
                y_hat = self.net(x_qry[i], self.net.parameters(), bn_training = True)
                loss_qry = F.cross_entropy(y_hat, y_qry[i])
                loss_list_qry[0] += loss_qry
                pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)  # size = (75)
                correct = torch.eq(pred_qry, y_qry[i]).sum().item()
                correct_list[0] += correct
            
            # 使用更新后的数据在query集上测试。
            with torch.no_grad():
                y_hat = self.net(x_qry[i], fast_weights, bn_training = True)
                loss_qry = F.cross_entropy(y_hat, y_qry[i])
                loss_list_qry[1] += loss_qry
                pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)  # size = (75)
                correct = torch.eq(pred_qry, y_qry[i]).sum().item()
                correct_list[1] += correct   
            
            for k in range(1, self.update_step):
                
                y_hat = self.net(x_spt[i], params = fast_weights, bn_training=True)
                loss = F.cross_entropy(y_hat, y_spt[i])
                grad = torch.autograd.grad(loss, fast_weights)
                tuples = zip(grad, fast_weights) 
                fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))
                    
                y_hat = self.net(x_qry[i], params = fast_weights, bn_training = True)
                loss_qry = F.cross_entropy(y_hat, y_qry[i])
                loss_list_qry[k+1] += loss_qry
                
                with torch.no_grad():
                    pred_qry = F.softmax(y_hat,dim=1).argmax(dim=1)
                    correct = torch.eq(pred_qry, y_qry[i]).sum().item()
                    correct_list[k+1] += correct
#         print('hello')
                
        loss_qry = loss_list_qry[-1] / task_num
        self.meta_optim.zero_grad() # 梯度清零
        loss_qry.backward()
        self.meta_optim.step()
        
        accs = np.array(correct_list) / (query_size * task_num)
        loss = np.array(loss_list_qry) / ( task_num)
        return accs,loss

    
    
    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        assert len(x_spt.shape) == 4
        
        query_size = x_qry.size(0)
        correct_list = [0 for _ in range(self.update_step_test + 1)]
        
        new_net = deepcopy(self.net)
        y_hat = new_net(x_spt)
        loss = F.cross_entropy(y_hat, y_spt)
        grad = torch.autograd.grad(loss, new_net.parameters())
        fast_weights = list(map(lambda p:p[1] - self.base_lr * p[0], zip(grad, new_net.parameters())))
        
        # 在query集上测试,计算准确率
        # 这一步使用更新前的数据
        with torch.no_grad():
            y_hat = new_net(x_qry,  params = new_net.parameters(), bn_training = True)
            pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)  # size = (75)
            correct = torch.eq(pred_qry, y_qry).sum().item()
            correct_list[0] += correct

        # 使用更新后的数据在query集上测试。
        with torch.no_grad():
            y_hat = new_net(x_qry, params = fast_weights, bn_training = True)
            pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)  # size = (75)
            correct = torch.eq(pred_qry, y_qry).sum().item()
            correct_list[1] += correct

        for k in range(1, self.update_step_test):
            y_hat = new_net(x_spt, params = fast_weights, bn_training=True)
            loss = F.cross_entropy(y_hat, y_spt)
            grad = torch.autograd.grad(loss, fast_weights)
            fast_weights = list(map(lambda p:p[1] - self.base_lr * p[0], zip(grad, fast_weights)))
            
            y_hat = new_net(x_qry, fast_weights, bn_training=True)
            
            with torch.no_grad():
                pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
                correct = torch.eq(pred_qry, y_qry).sum().item()
                correct_list[k+1] += correct
                
        del new_net
        accs = np.array(correct_list) / query_size
        return accs      

4 全部源码

你可以在我的github中找到全部源码。github账号:miguealanmath
传送门

5 实验结果

我跑了6万轮,但是测试集上最高的准确率只有92%,而作者的论文中达到了98%。很多外国网友也复现了MAML,普遍的印象是,最终的准确率会比作者报告的98%略低几个百分点。同时在训练过程中,MAML多次出现了训练数据过拟合的现象。这应该也是最终在测试集上准确率较低的原因。

有兴趣的朋友可以多调参试试。

PS: 这篇文章的更新可以查看我的另一篇博客

你可能感兴趣的:(Pytorch,元学习,pytorch,神经网络,深度学习,人工智能,过拟合)