小样本学习论文阅读 | Confess: A framework for single source cross-domain few-shot learning, ICLR 2022 poster

1. motivation

目前的方法在源域和目标域存在较大域间偏差时实用性较差。本文认为:

1) 无监督学习可以缓解监督崩溃问题,并且训练得到的模型可以更好地推广到目标域中。

2) 因为源数据集和目标数据集之间存在很大差异,因此对源任务有用的特征可能对目标任务没有帮助,甚至有害。所以本文期望在小样本的情况下,通过提取更少的特征来提升泛化性能。

2. contribution

本文提出了一个“对比学习和特征选择系统”(Comparative learning and Feature Selection System)的小样本学习框架Confess,解决了基类和新类之间存在较大域偏移的问题。包含三个部分:

1) 在源域上基于对比损失无监督训练backbone;

2) 引入了mask module在目标域上训练来选择更适合目标域分类的相关特征;

3) 在目标域上微调分类器和backbone 。

实验部分在ECCV2020 challenge benchmark 上取得了很好的效果。

3. 核心内容

3.1 overall framework

小样本学习论文阅读 | Confess: A framework for single source cross-domain few-shot learning, ICLR 2022 poster_第1张图片

3.2 无监督训练backbone

在预训练阶段,使用各种变换从训练批中的现有样本中扩充样本,并使用这些扩充样本和原始样本来计算对比损失。

具体来说,在每个批次中有 个样本,表示为 。对于每个样本 ,都用 个变化得到对应的扩充样本 。让扩充样本 的特征接近原样本 并远离其他样本 ,使用如下交叉熵损失:

具体的变化方式:color distortion (A Simple Framework for Contrastive Learning of Visual Representations, ICML2020)。

3.3 训练mask generator

从源域上训练得到的特征提取器为 。给定目标域上的样本,可以得到每个样本的特征为:

小样本学习论文阅读 | Confess: A framework for single source cross-domain few-shot learning, ICLR 2022 poster_第2张图片

将该特征输入到mask生成模块M中得到对应的mask:

小样本学习论文阅读 | Confess: A framework for single source cross-domain few-shot learning, ICLR 2022 poster_第3张图片

根据得到的mask,可以为每个特征得到对应的positive和negative的特征:

这里希望确保positive feature是有类别区分性的,而negative feature是没有类别区分性的:

其中, 是交叉熵损失, 是两个线性分类器。。

除此之外,最大化正特征集合 和负特征 集合之间的统计距离:

上述损失被联合起来用于训练mask generator:

3.4 微调过程

定义在目标域上的特征提取器为 ,其初始化为 ,对于每个目标域上的样本可以得到其特征为 。将其输入到线性分类器C中,计算交叉熵损失:

除此之外,Positive feature为:

提取到的特征应该和positive feature更接近:

那么微调损失就包含两个部分,分类损失和回归损失:

小样本学习论文阅读 | Confess: A framework for single source cross-domain few-shot learning, ICLR 2022 poster_第4张图片

4. 实验部分

在跨域上得到的效果非常好:

小样本学习论文阅读 | Confess: A framework for single source cross-domain few-shot learning, ICLR 2022 poster_第5张图片

5. 总结

存在一些有疑问的地方,mask generator是基于 在目标域上训练获得的,他是否适用于 提取得到的特征,这里的训练逻辑上有点绕。在迁移学习中,会在源域上预训练得到一个模型,然后再目标域上训练时固定大部分参数,微调剩下的部分参数,这里的mask起到的作用似乎也是类似的,就是把模型迁移到目标域上,但是又在目标域上继续微调了特征提取器。

你可能感兴趣的:(基于度量的元学习,跨域小样本学习,小样本学习,学习,论文阅读,深度学习,计算机视觉)