meta—learning调研及MAML概述

背景

Meta Learning,又称为 learning to learn,Meta Learning希望使得模型获取一种“学会学习”的能力,使其可以在获取已有“知识”的基础上快速学习新的任务,对于新的类别,只需要少量的样本就能快速学习(Few-shot Learning)。

Few-shot Learning 是 Meta Learning 在监督学习领域的应用。

数据集

早期研究都基于以下两个图像数据集:

Omniglot:https://github.com/brendenlake/omniglot

包含1623个不同的火星文字符,每个字符包含20个手写的case

miniImageNet:https://github.com/yaoyao-liu/mini-imagenet-tools

包含100类共60000张彩色图片,其中每类有600个样本

主流算法

MAML(入门+重要)

2017年发表,到2022年7月12日已经收获493的引用 https://arxiv.org/pdf/1703.03400.pdf

MAML与其说是一个深度学习模型,倒不如说是一个框架,提供一个meta-learner(MAML的精髓所在,learing to learn)用于训练base-learner(根据新数据实际用于预测任务的模型)。

绝大多数深度学习模型都可以作为base-learner无缝嵌入MAML中。

(一)目的

MAML的目的是获取一组更好的模型初始化参数(即让模型自己学会初始化)

可以这么理解:假设我们目前有3个tasks,分别为T 1 , T 2 , T 3 。按照以前模型的训练方式,首先,我们随机初始化模型参数θ,然后开始训练任务T 1 ,接着最小化损失函数L 来更新网络的参数,这样我们就会得到新的参数θ 1 。同理,我们可以接着更新其他两个任务。但以前模型的训练方式,是每个任务都是随机初始化θ开始,每个任务都是独立的。如果我们把三个任务初始化的θ到公用的位置,则不需要更多的梯度更新步骤。MAML就是做这件事的。

(二)专有术语介绍:

构建的任务分为训练任务(Train Task),测试任务(Test Task)。

每个任务都有自己的训练集(Support Set)、测试集( Query Set

N-ways,K-shot(数据中包含N个类别,每个类别有K个样本)

(三)训练流程

以训练 miniImage 数据集为例,按4:1划分数据集

Train Task:从训练集(80 个类,每类 600 个样本)中随机采样 5 个类,每个类 1 个样本(5-way 1-shot),构成Support Set,去学习 learner;然后从训练集的样本(采出的5 个类,每类剩下的样本)中抽 15 个样本采样构成Query Set,用来获得 learner 的 loss,去学习 meta leaner。

Test Task:(20 个类,每类 600 个样本)中随机采样5个类,每个类1 个样本(与training阶段一致,5-way 1-shot),构成支撑集 Support Set,去学习 learner;然后从测试集剩余的样本(采出的5 个类,每类剩下的样本)中抽 15 个样本采样构成 Query Set,用来获得 learner 的参数,进而得到预测的类别概率。

(四)实现代码

## 网络构建部分: refer: https://github.com/dragen1860/MAML-TensorFlow
​
#################################################
# 任务描述:5-ways,1-shot图像分类任务,图像统一处理成 84 * 84 * 3 = 21168的尺寸。
# support set:5 * 1
# query set:5 * 15
# 训练取1个batch的任务:batch size:4
# 对训练任务进行训练时,更新5次:K = 5
#################################################
​
print(support_x) # (4, 5, 21168) 
print(query_x) # (4, 75, 21168)
print(support_y) # (4, 5, 5)
print(query_y) # (4, 75, 5)
print(meta_batchsz) # 4
print(K) # 5
​
model = MAML()
model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train')
​
class MAML:
    def __init__(self):
        pass
    def build(self, support_xb, support_yb, query_xb, query_yb, K, meta_batchsz, mode='train'):
        """
        :param support_xb: [4, 5, 84*84*3] 
        :param support_yb: [4, 5, n-way]
        :param query_xb:  [4, 75, 84*84*3]
        :param query_yb: [4, 75, n-way]
        :param K:  训练任务的网络更新步数
        :param meta_batchsz: 任务数,4
        """
​
        self.weights = self.conv_weights() # 创建或者复用网络参数;训练任务对应的网络复用meta网络的参数
        training = True if mode is 'train' else False      
        def meta_task(input):
            """
            :param support_x:   [setsz, 84*84*3] (5, 21168)
            :param support_y:   [setsz, n-way] (5, 5)
            :param query_x:     [querysz, 84*84*3] (75, 21168)
            :param query_y:     [querysz, n-way] (75, 5)
            :param training:    training or not, for batch_norm
            :return:
            """
​
            support_x, support_y, query_x, query_y = input
            query_preds, query_losses, query_accs = [], [], [] # 子网络更新K次,记录每一次queryset的结果
 
            ## 第0次对网络进行更新
            support_pred = self.forward(support_x, self.weights, training) # 前向计算support set
            support_loss = tf.nn.softmax_cross_entropy_with_logits(logits=support_pred, labels=support_y) # support set loss
            support_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),
                                                         tf.argmax(support_y, axis=1))
            grads = tf.gradients(support_loss, list(self.weights.values())) # 计算support set的梯度
            gvs = dict(zip(self.weights.keys(), grads))
            # 使用support set的梯度计算的梯度更新参数,theta_pi = theta - alpha * grads
            fast_weights = dict(zip(self.weights.keys(), \
                    [self.weights[key] - self.train_lr * gvs[key] for key in self.weights.keys()]))
​
            # 使用梯度更新后的参数对quert set进行前向计算
            query_pred = self.forward(query_x, fast_weights, training)
            query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
            query_preds.append(query_pred)
            query_losses.append(query_loss)
 
            # 第1到 K-1次对网络进行更新
            for _ in range(1, K):           
                loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.forward(support_x, fast_weights, training),
                                                               labels=support_y)
                grads = tf.gradients(loss, list(fast_weights.values()))
                gvs = dict(zip(fast_weights.keys(), grads))
                fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.train_lr * gvs[key]
                                         for key in fast_weights.keys()]))
                query_pred = self.forward(query_x, fast_weights, training)
                query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
                # 子网络更新K次,记录每一次queryset的结果
                query_preds.append(query_pred)
                query_losses.append(query_loss)
​
            for i in range(K):
                query_accs.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(query_preds[i], dim=1), axis=1),
                                                                tf.argmax(query_y, axis=1)))
            result = [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
            return result
​
        # return: [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
        out_dtype = [tf.float32, tf.float32, tf.float32, [tf.float32] * K, [tf.float32] * K, [tf.float32] * K]
        result = tf.map_fn(meta_task, elems=(support_xb, support_yb, query_xb, query_yb),
                           dtype=out_dtype, parallel_iterations=meta_batchsz, name='map_fn')
        support_pred_tasks, support_loss_tasks, support_acc_tasks, \
            query_preds_tasks, query_losses_tasks, query_accs_tasks = result
​
        if mode is 'train':
            self.support_loss = support_loss = tf.reduce_sum(support_loss_tasks) / meta_batchsz
            self.query_losses = query_losses = [tf.reduce_sum(query_losses_tasks[j]) / meta_batchsz
                                                    for j in range(K)]
            self.support_acc = support_acc = tf.reduce_sum(support_acc_tasks) / meta_batchsz
            self.query_accs = query_accs = [tf.reduce_sum(query_accs_tasks[j]) / meta_batchsz
                                                    for j in range(K)]
​
            # 更新meta网络,只使用了第 K步的query loss。这里应该是个超参,更新几步可以调调
            optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim')
            gvs = optimizer.compute_gradients(self.query_losses[-1])
   # def ********

参考:

1.原论文:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networkshttps://arxiv.org/pdf/1703.03400.pdf

2.小样本学习(Few-shot Learning)综述小样本学习(Few-shot Learning)综述

3.一文入门元学习(Meta-Learning)(附代码)一文入门元学习(Meta-Learning)(附代码) - 知乎

4.Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解 - 知乎

5.从代码上解析Meta-learning从代码上解析Meta-learning_洛克-李的博客-CSDN博客

你可能感兴趣的:(普通人的搬砖日子,人工智能,深度学习,算法)