Cross Attention Network for Few-shot Classification

作者:一颗柠檬味的橙子
链接:https://zhuanlan.zhihu.com/p/105717426
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
 

来源:NeurIPS 2019
文章题目:Cross Attention Network for Few-shot Classification
下载地址: https://arxiv.org/abs/1910.07677

本文的主要工作是研究小样本学习(Few-shot Learning)在图像分类中的应用。小样本学习当前比较流行的解决思路就是元学习(meta-learning),元学习从一组任务中训练元学习器,提取元知识并将其转化到新任务中。总的来说元学习的方法分为三个类型:

(1)Optimization-based methods:将元学习器设计为一个学习更新模型参数的优化器,从而是其他模型学习到一个较好的初始值,从而尽快适应新任务。

(2)Parameter-generating based methods:通常将学习器设计为参数预测网络。

(3)Metric-learning based methods:学习一个公共特征空间,根据距离度量进行分类 。

文中作者使用的是Metric-learning based method。不同于传统方法:文中首先独立提取支持集和查询集的样本特征,利用支持集和查询集特性之间的语义相关性来突出显示目标对象。此外传统的注意力模型(例如SENet)只是基于训练类的先验来定位测试图像的重要区域,而不能推广到未知类的测试图像。因此本文中设计了一个元学习器来计算支持集和查询集特征图之间的交叉注意力图,这有助于定位目标对象的重要区域并增强特征的可识别性。

Cross Attention Module

本文中的小样本分类任务包含了训练集(包含了大量的标签和类别)、【支持集(包含了少量标签和类别,且与训练集不相交)和查询集(无标签信息,与支持集在同一标签空间)】。

Cross Attention Network for Few-shot Classification_第1张图片

图1 Cross Attention Module

图中绿色表示支持集的特征,蓝色表示查询集的特征。如图本文设计一个Correlation layer去计算支持集和查询集之间的关联。其计算方式如下:

上述 (支持集)表示局部类别特征向量和所有查询特征向量之间的关系, (查询集)表示局部查询特征向量和所有类别特征向量之间的关系。

图1(b),使用Meta fusion layer根据相应的相关映射分别生成类和查询注意力映射。Meta fusion layer使用一个 核为 ( )的卷积操作,本文的加权聚合应该将注意力吸引到目标对象上,而不是简单地突出显示支持集和查询集之间在视觉上相似的区域。

基于上述分析,作者设计了一个元学习器,根据类别特征和查询特征之间的相关性自适应地生成核。元学习的函数表示如下,其中GAP表示为全局平均池化,

Cross Attention Network

Cross Attention Network for Few-shot Classification_第2张图片

图2 Cross Attention Network

如图2所示,Cross Attention Network(CAN)主要包括一个Embedding操作和Cross Attention Module,Embedding主要是用于图像特征提取,Cross Attention Module如图1所示。CAN最后通过一个局部分类器和一个全局分类器组成。局部分类器通过支持集特征和查询集特征之间的余弦距离,计算两个特征之间的相似度从而得到查询集特征的概率值。全局分类器通过一个全连接层之后直接通过Softmax进行分类。模型优化过程中通过叠加局部分类器的损失和全局分类器的损失得到最终的损失函数:

Cross Attention Network for Few-shot Classification_第3张图片

blue-blue272/fewshot-CAN

你可能感兴趣的:(paper,reading)