Domain Prompt Learning for Efficiently Adapting CLIP to Unseen Domains

首先介绍一下域泛化 (Domain generalization):从若干个具有不同数据分布的数据集(领域)中学习一个泛化能力强的模型,以便在未知 (Unseen)的域上取得较好的效果(感觉和CLIP的zero-shot很适配)。相比之下,DG比DA更有挑战性。

Abstract

域泛化 (DG) 旨在为不可见域学习可泛化的模型。尽管 ERM(Empirical Risk Minimization)使用标准 DG 基准通过更大的backbone和训练数据集大大提高了准确性,但微调这种 FMs(foundation models )在许多现实场景下并不实用。

因此,作者提出 DPL(Domain Prompt Learning)以条件提示生成( conditional prompt generation)的形式进行领域推理(domain inference)的新方法。并且DPL 仅通过训练轻量级prompt generator(三层 MLP)就实现了显著的精度提升(参数量少)。

并且将DPL与CLIP结合展现了非常好的效果。

Introduction

CLIP显示出在跨视觉任务学习可迁移表示方面的巨大潜力。 其核心是通过将图像表示与图像的文本描述表示进行对比来学习图像表示。文本描述通常称为prompt,其设计对于增强 CLIP 性能至关重要。值得注意的是,CLIP 可以通过使用目标类名称充分更改文本描述来处理看不见的类,而无需对其进行微调。

之前的DG大多关注中等规模的预训练模型。

Domain Prompt Learning for Efficiently Adapting CLIP to Unseen Domains_第1张图片

两种直观的方法:a)微调由 CLIP 训练的image encoder,类似于 ResNet 和 ViT 等其他视觉模型。b)设计prompt模板,例如“a photo of a {class name}”,明显优点是它不需要优化任何网络,因此可以保留预训练学习到的表征。但是其效果不如第一种。

综上,作者提出Domain Prompt Learning (DPL),是CLIP在DG的扩展,一种自然的方法是在prompt template 中添加域特定的特征。因为歧义性与模糊,手动添加比较困难,因此作者建议 DPL 自动生成一个prompt,该prompt在给定每个分布的未标记数据的情况下估计特定领域的特征

具体来说,DPL 使用源域训练一个轻量级prompt generator,它在给定每个分布输入图像的情况下,输出固定长度连续域提示(fixed-length continuous domain prompts),同时冻结其他网络。在测试期间,prompt generator根据目标分布中的输入图像生成domain prompt,并将它们添加到标签提示中。由于整个网络被冻结,预训练的核心属性将保留在 DPL 中。

作者说他们的贡献:

1. 通过prompt learning将 CLIP 引入标准 DG 基准 DomainBed。

2. 提出领域提示学习(DPL),一种新的领域推理方法,通过利用特定域的特征有效地帮助DG。

3. 实验效果好。

Related Work

Test Time Adaptation

本文的方法也可以被看作Test-Time Adaptation,TTA 的概念是更新网络的一部分以最小化预测熵,以便在测试时使模型稳健地适应未知域

Prompt Learning

除了nlp那边,由于 CLIP 的成功应用,prompt turning在计算机视觉中也引起了极大的兴趣。上下文优化 (Context Optimization:CoOp) 表明 CLIP 性能容易受到prompt的影响,并且合适的prompt可以提高图像识别任务的性能。 CLIP-Adapter 提议通过额外的适配器(adapter)网络进行学习。

与以上这些需要访问目标域中的图像或类标签的方法不同,本文通过从输入图像中推断生成域提示 (domain prompt) 来改编 CLIP 以适应一个未见过域(unseen domain)。

Method

先前的方法:先前的工作结合随机初始化的分类头 g(linear classifier)对预训练图像编码器 f(ResNet18 or ResNet50)进行微调,使用来自多个不同数据集的数据来实现目标。也就是用多个不同域(数据集)的数据进行训练。DG 中的不同方法通过设计正则化项使用其他损失函数来防止过拟合到特定域。

Naive Approaches

简单将CLIP应用到DG中的话,分两种:1)zero-shot;2)仅使用CLIP的image encoder并对其进行FT,但是其需要额外的计算成本,而且准确度也不如直接zero-shot。原因:大规模预训练的良好特性可能在FT过程中被扭曲破坏。

CLIP + DPL

给定图像x 和 K 类prompt:,CLIP 使用 输出预测:

其中,K是类别个数, < >是计算余弦相似度。

Designing a prompt 是一个很有效易于训练(因为参数量很小)的方法,我们通过监督损失优化前缀向量(a prefix vector )

其中是:

其中 是可训练参数 的拼接(concatenation)。特别是,无论输入维度(即 的大小)如何, 都会输出固定长度的向量, 的大小是一个超参数。不幸的是,目标域的这种标记训练数据在 DG 中不可用。

DPL 通过训练新颖的prompt生成器 F (·) 来替换每个域中 的优化过程,该prompt生成器 F (·) 根据给定分布的未标记图像的情况下生成提示 。具体来说,我们使用全连接网络 F (·) 从输入图像生成prompt p

N 是每个域的batch size大小, 表示来自第 i 个分布的图像。表示的是分布i的prompt。

给定来自多个源分布的一批数据,我们使用以下损失函数来优化 F:

Domain Prompt Learning for Efficiently Adapting CLIP to Unseen Domains_第2张图片

其中是预定义的 的拼接。是类别prompt。

Architecture

Domain Prompt Learning for Efficiently Adapting CLIP to Unseen Domains_第3张图片

我们只训练了一个prompt generator F (·)。

首先,将输入图像用冻结的 CLIP 图像编码器 f(·) 编码以获得图像特征。将图像特征输入域提示生成器F(·)生成domain prompt。同时,使用冻结的 CLIP 文本编码器 g(·) 对所有labels进行编码,以获得label prompt embeddings

将domain prompt embeddings 和 label prompt embeddings 相加。

计算image embeddings 和 domain prompt embeddings的余弦相似度,获得概率输出。

Future Work

有两种简单而关键的方法可以提高 DG 的性能。一种是将visual prompt tuning应用到纯视觉主干上,这可以用于更多以前的方法。另一个关注以数据为中心的方法,因为我们观察到广泛使用的数据集上的数据质量参差不齐。

最近的几项研究系统地分析了大规模预训练模型在分布外泛化(Out-of-Distribution generalization)中的性能和缺点,我们希望我们的结果能激发更多朝这个方向的研究。

你可能感兴趣的:(计算机视觉)