Meta Learning,又称为 learning to learn,Meta Learning希望使得模型获取一种“学会学习”的能力,使其可以在获取已有“知识”的基础上快速学习新的任务,对于新的类别,只需要少量的样本就能快速学习(Few-shot Learning)。
Few-shot Learning 是 Meta Learning 在监督学习领域的应用。
早期研究都基于以下两个图像数据集:
Omniglot:https://github.com/brendenlake/omniglot
包含1623个不同的火星文字符,每个字符包含20个手写的case
miniImageNet:https://github.com/yaoyao-liu/mini-imagenet-tools
包含100类共60000张彩色图片,其中每类有600个样本
2017年发表,到2022年7月12日已经收获493的引用 https://arxiv.org/pdf/1703.03400.pdf
MAML与其说是一个深度学习模型,倒不如说是一个框架,提供一个meta-learner(MAML的精髓所在,learing to learn)用于训练base-learner(根据新数据实际用于预测任务的模型)。
绝大多数深度学习模型都可以作为base-learner无缝嵌入MAML中。
MAML的目的是获取一组更好的模型初始化参数(即让模型自己学会初始化)
可以这么理解:假设我们目前有3个tasks,分别为T 1 , T 2 , T 3 。按照以前模型的训练方式,首先,我们随机初始化模型参数θ,然后开始训练任务T 1 ,接着最小化损失函数L 来更新网络的参数,这样我们就会得到新的参数θ 1 。同理,我们可以接着更新其他两个任务。但以前模型的训练方式,是每个任务都是随机初始化θ开始,每个任务都是独立的。如果我们把三个任务初始化的θ到公用的位置,则不需要更多的梯度更新步骤。MAML就是做这件事的。
构建的任务分为训练任务(Train Task),测试任务(Test Task)。
每个任务都有自己的训练集(Support Set)、测试集( Query Set)
N-ways,K-shot(数据中包含N个类别,每个类别有K个样本)
以训练 miniImage 数据集为例,按4:1划分数据集
Train Task:从训练集(80 个类,每类 600 个样本)中随机采样 5 个类,每个类 1 个样本(5-way 1-shot),构成Support Set,去学习 learner;然后从训练集的样本(采出的5 个类,每类剩下的样本)中抽 15 个样本采样构成Query Set,用来获得 learner 的 loss,去学习 meta leaner。
Test Task:(20 个类,每类 600 个样本)中随机采样5个类,每个类1 个样本(与training阶段一致,5-way 1-shot),构成支撑集 Support Set,去学习 learner;然后从测试集剩余的样本(采出的5 个类,每类剩下的样本)中抽 15 个样本采样构成 Query Set,用来获得 learner 的参数,进而得到预测的类别概率。
## 网络构建部分: refer: https://github.com/dragen1860/MAML-TensorFlow
#################################################
# 任务描述:5-ways,1-shot图像分类任务,图像统一处理成 84 * 84 * 3 = 21168的尺寸。
# support set:5 * 1
# query set:5 * 15
# 训练取1个batch的任务:batch size:4
# 对训练任务进行训练时,更新5次:K = 5
#################################################
print(support_x) # (4, 5, 21168)
print(query_x) # (4, 75, 21168)
print(support_y) # (4, 5, 5)
print(query_y) # (4, 75, 5)
print(meta_batchsz) # 4
print(K) # 5
model = MAML()
model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train')
class MAML:
def __init__(self):
pass
def build(self, support_xb, support_yb, query_xb, query_yb, K, meta_batchsz, mode='train'):
"""
:param support_xb: [4, 5, 84*84*3]
:param support_yb: [4, 5, n-way]
:param query_xb: [4, 75, 84*84*3]
:param query_yb: [4, 75, n-way]
:param K: 训练任务的网络更新步数
:param meta_batchsz: 任务数,4
"""
self.weights = self.conv_weights() # 创建或者复用网络参数;训练任务对应的网络复用meta网络的参数
training = True if mode is 'train' else False
def meta_task(input):
"""
:param support_x: [setsz, 84*84*3] (5, 21168)
:param support_y: [setsz, n-way] (5, 5)
:param query_x: [querysz, 84*84*3] (75, 21168)
:param query_y: [querysz, n-way] (75, 5)
:param training: training or not, for batch_norm
:return:
"""
support_x, support_y, query_x, query_y = input
query_preds, query_losses, query_accs = [], [], [] # 子网络更新K次,记录每一次queryset的结果
## 第0次对网络进行更新
support_pred = self.forward(support_x, self.weights, training) # 前向计算support set
support_loss = tf.nn.softmax_cross_entropy_with_logits(logits=support_pred, labels=support_y) # support set loss
support_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),
tf.argmax(support_y, axis=1))
grads = tf.gradients(support_loss, list(self.weights.values())) # 计算support set的梯度
gvs = dict(zip(self.weights.keys(), grads))
# 使用support set的梯度计算的梯度更新参数,theta_pi = theta - alpha * grads
fast_weights = dict(zip(self.weights.keys(), \
[self.weights[key] - self.train_lr * gvs[key] for key in self.weights.keys()]))
# 使用梯度更新后的参数对quert set进行前向计算
query_pred = self.forward(query_x, fast_weights, training)
query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
query_preds.append(query_pred)
query_losses.append(query_loss)
# 第1到 K-1次对网络进行更新
for _ in range(1, K):
loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.forward(support_x, fast_weights, training),
labels=support_y)
grads = tf.gradients(loss, list(fast_weights.values()))
gvs = dict(zip(fast_weights.keys(), grads))
fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.train_lr * gvs[key]
for key in fast_weights.keys()]))
query_pred = self.forward(query_x, fast_weights, training)
query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
# 子网络更新K次,记录每一次queryset的结果
query_preds.append(query_pred)
query_losses.append(query_loss)
for i in range(K):
query_accs.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(query_preds[i], dim=1), axis=1),
tf.argmax(query_y, axis=1)))
result = [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
return result
# return: [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
out_dtype = [tf.float32, tf.float32, tf.float32, [tf.float32] * K, [tf.float32] * K, [tf.float32] * K]
result = tf.map_fn(meta_task, elems=(support_xb, support_yb, query_xb, query_yb),
dtype=out_dtype, parallel_iterations=meta_batchsz, name='map_fn')
support_pred_tasks, support_loss_tasks, support_acc_tasks, \
query_preds_tasks, query_losses_tasks, query_accs_tasks = result
if mode is 'train':
self.support_loss = support_loss = tf.reduce_sum(support_loss_tasks) / meta_batchsz
self.query_losses = query_losses = [tf.reduce_sum(query_losses_tasks[j]) / meta_batchsz
for j in range(K)]
self.support_acc = support_acc = tf.reduce_sum(support_acc_tasks) / meta_batchsz
self.query_accs = query_accs = [tf.reduce_sum(query_accs_tasks[j]) / meta_batchsz
for j in range(K)]
# 更新meta网络,只使用了第 K步的query loss。这里应该是个超参,更新几步可以调调
optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim')
gvs = optimizer.compute_gradients(self.query_losses[-1])
# def ********
1.原论文:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networkshttps://arxiv.org/pdf/1703.03400.pdf
2.小样本学习(Few-shot Learning)综述小样本学习(Few-shot Learning)综述
3.一文入门元学习(Meta-Learning)(附代码)一文入门元学习(Meta-Learning)(附代码) - 知乎
4.Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解 - 知乎
5.从代码上解析Meta-learning从代码上解析Meta-learning_洛克-李的博客-CSDN博客