小样本学习(Few-Shot Learning, FSL)任务,顾名思义,就是能够仅通过一个或几个示例就快速建立对新概念的认知能力。实现小样本学习的方式也有很多,比如:度量学习、数据增强、预训练模型、元学习等等,其中元学习是目前广泛使用的处理小样本学习问题的方法。
元学习(meta learning或learning to learn),也称学会学习,元学习算法能够在学习不同任务的过程中积累经验,从而使得模型能够快速适应新任务。
元学习与一般的监督学习的区别:
一般的监督学习是在训练集上训练出一个函数映射 ,这个函数可以识别出哪张图片是狗,哪张是猫。输入是某张图片,输出是标签。
元学习算法则是让模型学会学习,即在训练集上学习出一个函数 , 可以自动学习出一个函数 ,他可以分辨出哪张图片是狗,哪张是猫。 的输入是一个个的图片集合,输出是函数 , 的输入是某张图片,输出是标签。
可以这么理解使用元学习算法的小样本学习任务:我有一个数据集{大象,老虎,狮子},小样本学习并非是让模型识别出哪个是老虎、大象或者狮子,而是学习出每个类别之间的差异,以便在新的数据集(比如:{汽车、电视、沙发、鼠标})中更好的分类。
为了更形式化评估元学习算法,在分类问题上,元学习的数据形式和一般监督学习的数据形式也有所不同,最小的数据点不再是一张图片,而是一个一个的小任务。每个小任务中有 个类别,每个类别有 张图片,我们称这些任务为N-way K-shot图像分类任务,一共有 个小任务。当K值很小时(一般K<10),该任务就是小样本图像分类任务了。当K=1时,该任务即为单样本图像分类任务。
除此之外我们还需要知道两个重要概念:
如下图,为一个3-way 2-shot图像分类任务,蓝色板块是支持集,绿色的是查询集:
注意,对于元学习而言,上图的3-way 2-shot图像分类任务只是一个数据点,完整的数据集及其训练集-测试集划分如下图所示:
Black Box / Model-based
为小数据集场景专门制定一个能够快速变化参数的模型。代表作有:MANN,MetaNet等。
Optimization Based
通过让模型快速优化自己的参数来实现小样本学习。代表作有:MAML,NAIL,Reptile等。
Metric Based(基于度量的方法)
也是目前主流的方法,通过学习一个Encoder,将数据映射到一个表征空间,然后使用无参的Decoder来进行分类。代表作有Matching Network,Prototypical Network等。
以一个episode为例,其中包含 {狗,虫子,鸟} 三个类别的图片。
1、首先,对支持集中的每张图片使用编码器 进行信息提取,学习到每张图片的Embedding编码表示。(编码器可以选择常规的卷积操作、resnet系列、vit等等)
2、 然后对支持集的每个类别下的Embeddings做均值处理,得到每个类别的原型表示(class Prototype)。
3、对查询集中的图片进行分类。首先使用编码器 将查询集图片进行编码,得到该图片的Embedding向量表示。
4、然后拿着这个Embedding表示和类别原型进行相似度计算,也就是无参的解码过程。(相似度计算的方式很多,可以是欧氏距离或者余弦相似度等)
5、计算完相似度后,往往还需要使用softmax将相似度激活成概率分布。
最终得到查询集图片的分类标签,然后和真实值标签做交叉熵loss,然后梯度反向传播即可完成一个episode的训练。
1、假设原始数据集为D,对于每一个episode,包含一个支持集和一个查询集,即 。
实现方法就是在原始数据集 D 中随机选取N个类别,每个类别选取K张图片,构成支持集,选取Q张图片,构成查询集,这样就组成了一个episode的小数据集。以此类推,构造 个小数据集。
2、 对每张图片,利用Encoder进行特征提取,即
3、计算出支持集中的每个类别的原型(prototype),即
其中, 表示图片 的类别标签。
4、接下来计算每个查询集图片Embedding与每个类别的相似度,即
5、训练用的损失函数,公式如下:
prototypical network是有官方的论文实现的【prototypical-network源码】,而且很多框架里自带原型网络的包,直接调用即可。
但是官方的论文源码比较难看,而且某些场景需要拆解组合时也不方便,因此这里我自己实现了一个精简版的原型网络【我自己的代码复现】。
一般来说way与shot和准确率的关系如下所示:
这个很好理解,一个episode中类别(way)越少,就越容易找出图片之间的异同,比如二分类就比十分类容易一些;一个episode中同一个类别的样本(shot)越多,就越容易找出图片之间的异同。
本质上来说,原型网络就是集成学习中的stacking思想。
一般来说,对于一个稍微有点机器学习基础的小白,要想在一个未知的ML分支领域快速实现一个比较靠谱的实例(比如打一个小样本学习的算法挑战赛,但我从没接触过小样本学习)。总体步骤如下:
参考资料:
综述:
https://arxiv.org/abs/2004.05439
【干货】如何通过元学习解决小样本图像分类任务 - 知乎
Meta Learning(元学习)_哔哩哔哩_bilibili
SOTA:
Few-Shot Classification Leaderboard
https://arxiv.org/abs/1703.05175v2
TechBeat
简单的实例:
https://arxiv.org/abs/1703.05175
元学习——原型网络(Prototypical Networks)_hei653779919的博客-CSDN博客_原型网络
元学习—MAML模型Pytorch实现_hei653779919的博客-CSDN博客