Getting Started with Meta Learning: Few-Shot Learning with MAML (Partial Reproduction of the Paper on Omniglot)

I recently watched Hung-yi Lee's MAML lectures and tried to implement it from scratch: 5-way 1-shot classification on the Omniglot dataset.

First, the resources I referenced:

Hung-yi Lee's lectures: https://www.youtube.com/watch?v=EkAqYbpCYAc

Original paper: https://arxiv.org/abs/1703.03400

Two sets of notes on Zhihu:

https://zhuanlan.zhihu.com/p/136975128

https://zhuanlan.zhihu.com/p/66926599

My one-sentence summary of the soul of MAML: for a family of similar tasks (classification, regression, RL, etc.), use some of the tasks to train a good parameter initialization, so that the remaining tasks can converge quickly from it.

The difference from transfer learning is that pre-training in transfer learning looks for parameters that are optimal for the pre-training task, whereas MAML looks for the parameters with the most potential, i.e., parameters from which only a few gradient-descent steps are needed for the loss to converge.
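In the notation of the paper, the inner loop adapts the shared initialization $\theta$ to each sampled task $\mathcal{T}_i$ with one (or a few) gradient steps on its support set, and the outer loop updates $\theta$ using the query-set loss of the adapted parameters $\theta_i'$:

$$\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta), \qquad \theta \leftarrow \theta - \beta \, \nabla_\theta \sum_{\mathcal{T}_i} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})$$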

With limited ability and time, I used the first-order approximation, which avoids computing Hessian matrices.
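The exact outer gradient $\nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})$ requires differentiating through the inner update, which brings in second-order (Hessian) terms. The first-order approximation treats $\theta_i'$ as if it did not depend on $\theta$, and simply applies the query-set gradient taken at $\theta_i'$ directly to the initialization:

$$\theta \leftarrow \theta - \beta \sum_{\mathcal{T}_i} \nabla_{\theta_i'} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})$$

This is exactly what the meta-update in the training code below implements.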

Here is the code.

Data preprocessing (produces the training and test sets, with shapes (1200, 20, 1, 28, 28) and (423, 20, 1, 28, 28) respectively):

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

import os
from PIL import Image
import numpy as np


transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor()
])

# background=True is the 964-character "background" split,
# background=False the 659-character "evaluation" split
trainset = torchvision.datasets.Omniglot(
    root='./data',
    download=True,
    background=True,
    transform=transform
)
testset = torchvision.datasets.Omniglot(
    root='./data',
    download=True,
    background=False,
    transform=transform
)
# Quick sanity check: visualize one character image from the raw datasets.
dataset = trainset + testset
print(len(dataset))
tmp = dataset[5][0].squeeze(0)
print(tmp.shape, dataset[5][1])
plt.imshow(tmp, cmap='gray')
plt.show()

root_dir = os.getcwd()

# Each split directory contains one folder per alphabet; each alphabet
# contains one folder per character, and each character folder holds 20
# drawings. Stack each character's drawings into a (20, 1, 28, 28) tensor
# and collect all characters along dim 0. Listings are sorted so the
# resulting train/test split is deterministic.
def load_split(base_dir, dataset=None):
    for category_name in sorted(os.listdir(base_dir)):
        num_dir = base_dir + '/' + category_name
        for number in sorted(os.listdir(num_dir)):
            file_dir = num_dir + '/' + number
            imgs = [transform(Image.open(file_dir + '/' + f).convert('L'))
                    for f in sorted(os.listdir(file_dir))]
            cate_tensor = torch.stack(imgs, dim=0)  # (20, 1, 28, 28)
            if dataset is None:
                dataset = cate_tensor.unsqueeze(0)
            else:
                dataset = torch.cat((dataset, cate_tensor.unsqueeze(0)), dim=0)
    return dataset

# 964 background characters followed by 659 evaluation characters (1623 total)
dataset = load_split(root_dir + '/data/omniglot-py/images_background')
dataset = load_split(root_dir + '/data/omniglot-py/images_evaluation', dataset)

print(dataset.shape)  # torch.Size([1623, 20, 1, 28, 28])
dataset = dataset.numpy()
np.save("train_data.npy", dataset[:1200])
np.save("test_data.npy", dataset[1200:])

Training and testing (I did not look closely at how the original authors evaluate, so I sample 100 rounds of tasks from the test set myself, 32 tasks per round, and compute the accuracy of each round):

import os
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt

from copy import deepcopy

print(torch.cuda.get_device_name())

trainset = np.load("train_data.npy")
testset = np.load("test_data.npy")

print(trainset.shape, testset.shape)

meta_batch_size = 32  # tasks per meta-update
alpha = 0.04          # inner-loop (per-task) learning rate
beta = 0.0001         # outer-loop (meta) learning rate

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=0)
        self.bn1 = nn.BatchNorm2d(64)

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)

        self.fc = nn.Linear(64 * 4 * 4, 128)
        self.out = nn.Linear(128, 5)

    def forward(self, x):
        # input: (N, 1, 28, 28)
        x = F.relu(self.bn1(self.conv1(x)))                       # (N, 64, 26, 26)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=0)   # (N, 64, 13, 13)

        x = F.relu(self.bn2(self.conv2(x)))                       # (N, 64, 13, 13)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)   # (N, 64, 7, 7)

        x = F.relu(self.bn3(self.conv3(x)))                       # (N, 64, 7, 7)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)   # (N, 64, 4, 4)

        x = F.relu(self.bn4(self.conv4(x)))                       # (N, 64, 4, 4)

        x = x.reshape(-1, 64 * 4 * 4)
        x = self.fc(x)
        x = self.out(x)

        return x
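
# Sanity check I added (not part of the original training flow): a dummy
# 5-way batch of 28x28 inputs should come out as (5, 5) logits, confirming
# the 64 * 4 * 4 flatten size annotated above.
_net = Net()
assert _net(torch.zeros(5, 1, 28, 28)).shape == (5, 5)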

def task_sample(mode):
    # Sample a 5-way 1-shot task: 5 random characters; for each, one support
    # drawing (index j) and one query drawing (index k) out of its 20.
    set_len = 1200 if mode == "train" else 423
    curset = trainset if mode == "train" else testset
    categories = random.sample(range(set_len), 5)
    # categories = [0, 1, 3, 50, 100]
    spt_x = None
    qry_x = None
    spt_y = torch.tensor([0, 1, 2, 3, 4])
    qry_y = torch.tensor([0, 1, 2, 3, 4])
    for _ in range(5):
        i = categories[_]
        j, k = random.sample(range(20), 2)

        cur_spt = torch.from_numpy(curset[i][j])
        cur_qry = torch.from_numpy(curset[i][k])
        # print("category:", i, "numbers:", j, k)
        if _ == 0:
            spt_x = cur_spt.unsqueeze(0)
            qry_x = cur_qry.unsqueeze(0)
        else:
            spt_x = torch.cat([spt_x, cur_spt.unsqueeze(0)], dim=0)
            qry_x = torch.cat([qry_x, cur_qry.unsqueeze(0)], dim=0)
    # print(spt_x.shape, spt_y.shape, qry_x.shape, qry_y.shape)
    return spt_x, spt_y, qry_x, qry_y
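
# Example I added: inspect one sampled task.
# Expected shapes: spt_x/qry_x -> (5, 1, 28, 28); spt_y/qry_y -> (5,)
_sx, _sy, _qx, _qy = task_sample("train")
print(_sx.shape, _sy.shape, _qx.shape, _qy.shape)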

class BaseLearner():
    def __init__(self, learning_rate, model):
        self.model = deepcopy(model)
        self.alpha = learning_rate
        self.opt = None

    def update(self, model, learning_rate):
        # Re-clone the current meta-model so each task adapts its own copy.
        self.model = deepcopy(model)
        self.opt = optim.SGD(self.model.parameters(), lr=learning_rate)

    def train_task(self):
        correct = 0
        self.model = self.model.cuda()
        spt_x, spt_y, qry_x, qry_y = task_sample("train")
        spt_x, spt_y, qry_x, qry_y = spt_x.cuda(), spt_y.cuda(), qry_x.cuda(), qry_y.cuda()

        # Inner loop: one SGD step on the support set (1-shot adaptation).
        ret = self.model(spt_x)
        loss = F.cross_entropy(ret, spt_y)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        # Query loss at the adapted parameters; its gradients are the
        # first-order meta-gradients returned to the MetaLearner.
        ret = self.model(qry_x)
        loss = F.cross_entropy(ret, qry_y)
        self.opt.zero_grad()
        loss.backward()

        correct += ret.argmax(dim=1).eq(qry_y).sum().item()

        self.model = self.model.cpu()
        # returns: loss, grads, number of correct query predictions
        return loss.item(), [ele.grad for ele in self.model.parameters()], correct

    def test_task(self):
        correct = 0
        self.model = self.model.cuda()
        spt_x, spt_y, qry_x, qry_y = task_sample("test")
        spt_x, spt_y, qry_x, qry_y = spt_x.cuda(), spt_y.cuda(), qry_x.cuda(), qry_y.cuda()

        # Adapt on the support set (one step here; raise the range for more).
        for i in range(1):
            ret = self.model(spt_x)
            loss = F.cross_entropy(ret, spt_y)
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

        ret = self.model(qry_x)
        loss = F.cross_entropy(ret, qry_y)
        correct += ret.argmax(dim=1).eq(qry_y).sum().item()
        self.model = self.model.cpu()
        return loss.item(), correct

class MetaLearner():
    def __init__(self, learning_rate, batch_size):
        self.model = Net()
        self.beta = learning_rate
        self.meta_batch_size = batch_size
        self.BL = BaseLearner(alpha, self.model)
        self.train_losses = list()

    def train_one_step(self):
        grads = list()
        losses = list()
        total_correct = 0
        for batch_id in range(self.meta_batch_size):
            self.BL.update(self.model, self.BL.alpha)
            cur = self.BL.train_task()
            grads.append(cur[1])
            losses.append(cur[0])
            total_correct += cur[2]
        # Meta-update: theta <- theta - beta * sum over tasks of the
        # query-set gradients (the first-order MAML update).
        paras = [para for para in self.model.named_parameters()]
        for batch_id in range(self.meta_batch_size):
            for i in range(len(paras)):
                paras[i][1].data = paras[i][1].data - self.beta * grads[batch_id][i].data

        return sum(losses) / self.meta_batch_size, total_correct / (self.meta_batch_size * 5)

    def train(self, epochs):
        for meta_epoch in range(epochs):
            cur_loss, acc = self.train_one_step()
            self.train_losses.append(cur_loss)
            if (meta_epoch + 1) % 1000 == 0:
                print("Meta Training Epoch:", meta_epoch+1)
                print("Loss:", cur_loss)
            # print("Train Accuracy:", acc)

    def test_one_step(self):
        total_correct = 0
        for batch_id in range(self.meta_batch_size):
            self.BL.update(self.model, self.BL.alpha)
            cur = self.BL.test_task()
            total_correct += cur[1]

        return total_correct / (self.meta_batch_size * 5)

    def test(self, epochs):
        for test_round in range(epochs):
            acc = self.test_one_step()
            print("Test Round:", test_round+1)
            print("Test Accuracy:", acc)

ML = MetaLearner(beta, meta_batch_size)

ML.train(20000)
plt.plot(ML.train_losses)
plt.show()

ML.test(100)
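
# Optional (my addition): save the learned meta-initialization so test-time
# adaptation can be rerun later without retraining; the filename is arbitrary.
torch.save(ML.model.state_dict(), "maml_omniglot_init.pt")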

Training results

Since I was running on a free Colab GPU, I only trained for 20000 epochs (the paper trains for 60000); the final test accuracy is about 90% (vs. 98% in the paper).

Meta Training Epoch: 1000
Loss: 1.0540012400597334
Meta Training Epoch: 2000
Loss: 0.7158879199996591
Meta Training Epoch: 3000
Loss: 0.6014841219875962
Meta Training Epoch: 4000
Loss: 0.5044218171387911
Meta Training Epoch: 5000
Loss: 0.4403148274868727
Meta Training Epoch: 6000
Loss: 0.28830910232500173
Meta Training Epoch: 7000
Loss: 0.30838979699183255
Meta Training Epoch: 8000
Loss: 0.16489165458187927
Meta Training Epoch: 9000
Loss: 0.2780265275214333
Meta Training Epoch: 10000
Loss: 0.30111221893457696
Meta Training Epoch: 11000
Loss: 0.2760705396740377
Meta Training Epoch: 12000
Loss: 0.27776027111394797
Meta Training Epoch: 13000
Loss: 0.21309451176421135
Meta Training Epoch: 14000
Loss: 0.21287523438513745
Meta Training Epoch: 15000
Loss: 0.2973926745034987
Meta Training Epoch: 16000
Loss: 0.2408616042957874
Meta Training Epoch: 17000
Loss: 0.15259935014910297
Meta Training Epoch: 18000
Loss: 0.14492448499731836
Meta Training Epoch: 19000
Loss: 0.19298083511239383
Meta Training Epoch: 20000
Loss: 0.287610400642734

[Figure 1: meta-training loss curve over the 20000 epochs]

......
Test Round: 90
Test Accuracy: 0.86875
Test Round: 91
Test Accuracy: 0.8625
Test Round: 92
Test Accuracy: 0.9125
Test Round: 93
Test Accuracy: 0.90625
Test Round: 94
Test Accuracy: 0.91875
Test Round: 95
Test Accuracy: 0.89375
Test Round: 96
Test Accuracy: 0.9375
Test Round: 97
Test Accuracy: 0.90625
Test Round: 98
Test Accuracy: 0.91875
Test Round: 99
Test Accuracy: 0.9125
Test Round: 100
Test Accuracy: 0.91875

A rough baseline is now in place; I will update this post after further optimization.
