【Pytorch】prototypical network原型网络小样本图像分类简述及其实现

基本概念

小样本学习(Few-Shot Learning, FSL)任务,顾名思义,就是能够仅通过一个或几个示例就快速建立对新概念的认知能力。实现小样本学习的方式也有很多,比如:度量学习、数据增强、预训练模型、元学习等等,其中元学习是目前广泛使用的处理小样本学习问题的方法。

元学习(meta learning或learning to learn),也称学会学习,元学习算法能够在学习不同任务的过程中积累经验,从而使得模型能够快速适应新任务。

元学习与一般的监督学习的区别:

一般的监督学习是在训练集上训练出一个函数映射 f,这个函数可以识别出哪张图片是狗,哪张是猫。输入是某张图片,输出是标签。

元学习算法则是让模型学会学习,即在训练集上学习出一个函数 FF 可以自动学习出一个函数 f^*,他可以分辨出哪张图片是狗,哪张是猫。F 的输入是一个个的图片集合,输出是函数 f^*, f^* 的输入是某张图片,输出是标签。 

【Pytorch】prototypical network原型网络小样本图像分类简述及其实现_第1张图片

可以这么理解使用元学习算法的小样本学习任务:我有一个数据集{大象,老虎,狮子},小样本学习并非是让模型识别出哪个是老虎、大象或者狮子,而是学习出每个类别之间的差异,以便在新的数据集(比如:{汽车、电视、沙发、鼠标})中更好的分类。

小样本学习图片分类的基本思想

为了更形式化评估元学习算法,在分类问题上,元学习的数据形式和一般监督学习的数据形式也有所不同,最小的数据点不再是一张图片,而是一个一个的小任务。每个小任务中有 N 个类别,每个类别有 K 张图片,我们称这些任务为N-way K-shot图像分类任务,一共有 episodes 个小任务。当K值很小时(一般K<10),该任务就是小样本图像分类任务了。当K=1时,该任务即为单样本图像分类任务。

除此之外我们还需要知道两个重要概念:

  • 支持集(Support Set):相当于每个小任务中的训练集,包含N个分类标签,每个标签有K张图片。
  • 查询集(Query Set):相当于每个小任务中的测试集,包含Q张未分类的图片。

如下图,为一个3-way 2-shot图像分类任务,蓝色板块是支持集,绿色的是查询集:

【Pytorch】prototypical network原型网络小样本图像分类简述及其实现_第2张图片

 注意,对于元学习而言,上图的3-way 2-shot图像分类任务只是一个数据点,完整的数据集及其训练集-测试集划分如下图所示:

【Pytorch】prototypical network原型网络小样本图像分类简述及其实现_第3张图片  

元学习流派

Black Box / Model-based
为小数据集场景专门制定一个能够快速变化参数的模型。代表作有:MANN,MetaNet等。

Optimization Based
通过让模型快速优化自己的参数来实现小样本学习。代表作有:MAML,NAIL,Reptile等。

Metric Based(基于度量的方法)
也是目前主流的方法,通过学习一个Encoder,将数据映射到一个表征空间,然后使用无参的Decoder来进行分类。代表作有Matching Network,Prototypical Network等。

Prototypical Network基本原理

以一个episode为例,其中包含 {狗,虫子,鸟} 三个类别的图片。

1、首先,对支持集中的每张图片使用编码器  f_\varphi (x_i) 进行信息提取,学习到每张图片的Embedding编码表示。(编码器可以选择常规的卷积操作、resnet系列、vit等等)

【Pytorch】prototypical network原型网络小样本图像分类简述及其实现_第4张图片

2、 然后对支持集的每个类别下的Embeddings做均值处理,得到每个类别的原型表示(class Prototype)。

【Pytorch】prototypical network原型网络小样本图像分类简述及其实现_第5张图片

  3、对查询集中的图片进行分类。首先使用编码器 f_\varphi (x_i) 将查询集图片进行编码,得到该图片的Embedding向量表示。

【Pytorch】prototypical network原型网络小样本图像分类简述及其实现_第6张图片

 4、然后拿着这个Embedding表示和类别原型进行相似度计算,也就是无参的解码过程。(相似度计算的方式很多,可以是欧氏距离或者余弦相似度等)

 【Pytorch】prototypical network原型网络小样本图像分类简述及其实现_第7张图片

 5、计算完相似度后,往往还需要使用softmax将相似度激活成概率分布。

最终得到查询集图片的分类标签,然后和真实值标签做交叉熵loss,然后梯度反向传播即可完成一个episode的训练。

Prototypical Network算法描述

1、假设原始数据集为D,对于每一个episode,包含一个支持集和一个查询集,即D_{episode}=D_{support}\cup D_{query}=\left \{ s_i \right \}^{n_s}_{i=1}\cup \left \{ q_i \right \}^{n_q}_{i=1}  。

实现方法就是在原始数据集 D 中随机选取N个类别,每个类别选取K张图片,构成支持集,选取Q张图片,构成查询集,这样就组成了一个episode的小数据集。以此类推,构造 episodes 个小数据集。

2、 对每张图片,利用Encoder进行特征提取,即

 hs_i=f_\varphi (s_i) 

 hq_i=f_\varphi (q_i) 

3、计算出支持集中的每个类别的原型(prototype),即

p_{c_j}=\sum _{\left \{ i|l_{s_i}=c_j \right \}}hs_i

其中,l_{s_i} 表示图片 s_i 的类别标签。

 4、接下来计算每个查询集图片Embedding与每个类别的相似度,即

p(\widehat{l}_{q_i}=c_j)=\frac{exp(sim(hq_i,p_{c_j}))}{\sum ^{|c|}_{k=1}exp(sim(hq_i,p_{c_k}))}

5、训练用的损失函数,公式如下:

L_{q_i} = CrossEntropy(l_{q_i},\widehat{l}_{q_i})

Prototypical Network 的 Pytorch实现

prototypical network是有官方的论文实现的【prototypical-network源码】,而且很多框架里自带原型网络的包,直接调用即可。

但是官方的论文源码比较难看,而且某些场景需要拆解组合时也不方便,因此这里我自己实现了一个精简版的原型网络【我自己的代码复现】。

总结

一般来说way与shot准确率的关系如下所示:

【Pytorch】prototypical network原型网络小样本图像分类简述及其实现_第8张图片

 这个很好理解,一个episode中类别(way)越少,就越容易找出图片之间的异同,比如二分类就比十分类容易一些;一个episode中同一个类别的样本(shot)越多,就越容易找出图片之间的异同。

本质上来说,原型网络就是集成学习中的stacking思想。


题外话

一般来说,对于一个稍微有点机器学习基础的小白,要想在一个未知的ML分支领域快速实现一个比较靠谱的实例(比如打一个小样本学习的算法挑战赛,但我从没接触过小样本学习)。总体步骤如下:

  1. 先用搜索引擎搜索几个相关条目,大体了解这个领域是干什么的;
  2. 找一些该领域的大腿、论坛或者订阅号,如sota排名网页、公众号、添加一下文献鸟的关键词等。最好找个人带
  3. 找几篇这个领域最新的综述,大体了解一下该领域各个流派之间的腥风血雨;
  4. 看几个这个领域比较火的代码实例;
  5. 找几篇最新的SOTA,最好带代码的;
  6. 添加自己的想法,落地实现并验证。

参考资料:

综述:

https://arxiv.org/abs/2004.05439

【干货】如何通过元学习解决小样本图像分类任务 - 知乎

Meta Learning(元学习)_哔哩哔哩_bilibili

SOTA:

Few-Shot Classification Leaderboard

https://arxiv.org/abs/1703.05175v2

TechBeat

简单的实例:

https://arxiv.org/abs/1703.05175

元学习——原型网络(Prototypical Networks)_hei653779919的博客-CSDN博客_原型网络

元学习—MAML模型Pytorch实现_hei653779919的博客-CSDN博客

你可能感兴趣的:(深度学习,Pytorch实现,pytorch,深度学习,分类)