[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models

文章链接:https://arxiv.org/pdf/2305.01115v1.pdf

代码:

GitHub - Zhendong-Wang/Prompt-Diffusion: Official PyTorch implementation of the paper "In-Context Learning Unlocked for Diffusion Models"

摘要:

我们提出了Prompt Diffusion,这是一个框架,用于在基于扩散的生成模型中实现上下文学习。

给定一对特定于任务的示例图像,例如depth from/to image和scribble from/to image,以及文本指导,我们的模型自动理解底层任务,并根据文本指导在新的查询图像上执行相同的任务。

为了实现这一目标,我们提出了一个vision-language prompt,它可以对广泛的视觉语言任务进行建模。扩散模型使用这些prompt在六个不同的任务上进行联合训练。由此产生的提示扩散模型是第一个能够在语境中学习的基于扩散的视觉语言基础模型。它在训练任务上展示了高质量的上下文生成,并有效地推广到具有各自提示的新的、未见过的视觉任务。

介绍:

通过适当设计prompt结构和in-context学习,LLMs(large language models)可以将多种语言任务的预训练结合起来,并很好地推广到以前未见过的任务。虽然语境学习在自然语言处理中得到了广泛的研究,但其在计算机视觉领域的应用仍然有限。本文旨在解锁文本引导的基于扩散的生成模型的上下文学习能力。我们引入了一个新的模型架构,提示扩散,在视觉语言提示下执行上下文学习,可以适应各种各样的视觉语言任务。我们在六种不同的视觉语言任务上共同训练提示扩散。

我们首先使用我们的视觉语言提示定义一个通用的视觉语言任务:

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第1张图片

其中(image1→image2)由一对视觉任务示例组成,例如(depth map→image)。文本指导为特定任务提供语言指令,image3是输入图像查询,它在类型上与image1对齐,因此可以是真实图像或图像条件(例如,depth map 或者 hed map)。然后,我们建立了提示扩散,灵感来自Stable Diffusion的设计和ControlNet,它可以将我们的视觉语言提示符作为输入。Prompt Diffusion成功地将六种不同任务的学习整合到一个视觉语言基础模型中。通过基于提示的学习,该模型可以有效地掌握输入样例对之间的底层关系。

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第2张图片

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第3张图片

3.方法

我们通过 成对的图像示例 和 图像查询 和 文本输入来设计视觉语言prompt。通过这种设计,我们提出了一种新的输入输出对格式,可以推广大多数视觉语言任务的输入输出配置。我们用一个例子来说明它:

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第4张图片

其中example:(image1→image2)告知模型目标任务是什么以及如何通过示例来完成它,例如,从head map生成图像。text-guidance引导模型在给定文本上生成图像条件,image-query: image3与image1在类型上对齐,表示特定任务的输入。example pair可以适应任何图像域转换任务,例如: forward tasks: image → segmentation/depth/hed mapd 和inverse tasks: segmentation/depth/hed map → image,而text-guidance的参与为生成目标图像提供了additional context

对于图像输入,我们在通道维度上拼接示例图像对,然后将拼接后的example pairimage query通过各自独立的卷积层投影到相同的维度嵌入中。我们计算两个嵌入的和,并将其输入ControlNet分支。对于文本输入,我们遵循Stable Diffusion[41]的惯例,通过预训练的CLIP[37]文本编码器对文本输入进行编码。由此产生的CLIP文本嵌入然后通过交叉注意层馈送到稳定扩散分支。

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第5张图片

我们保持Stable Diffusion分支和ControlNet分支与其原始版本相同,这样我们就可以从现有的checkpoint微调Prompt Diffusion,而不需要从头开始训练它。我们从Stable Diffusion ‘v1.5’ checkpoint微调Prompt Diffusion

In-Context Learning:

我们在联合训练中考虑了六种不同的视觉任务。当任务输入是干净的图像时,我们称之为正向任务,当任务输入是图像条件时,我们称之为逆任务,例如segmentation/dep/hed maps,如图2所示。

数据集:

我们使用Uniformer获得图像分割映射[21]。我们使用canny边缘检测器[5]收集canny边缘map,使用hed边界检测器[60]收集hed map。

Inverse Tasks:

对于数据集中的任何image-text对,(I1, C1),我们首先采样一个随机任务和另一个随机图像I2,以创建示例对,例如(HED(I2), I2)。然后,我们使用示例对、标题和与示例对指定的任务一致的图像条件构建prompt。图像I1是Prompt扩散去噪的目标。

一个完整的逆任务示例如下:

通过将hed map替换为另外两个图像条件,我们可以获得所有三个逆任务的the vision-language prompts

Forward Tasks.

我们还考虑了三个前向任务(图像处理任务): : images to depth maps, images to hed maps, 和images to segmentation maps。我们遵循相同的逆任务规则来创建the vision-language prompts注意,这里的任务是颠倒的,因此我们颠倒了示例对和查询目标对的顺序。图像的text对于图像处理任务来说不是必需的,所以我们为每个任务使用特定于任务的固定文本标签,例如“head maps”。我们展示了一个完整的前向任务示例如下:

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第6张图片

Joint Training

我们在这六种不同的视觉语言任务上共同训练提示扩散。为了简单起见,我们在这六个任务上均匀随机地训练我们的模型。具体来说,每个小批数据包含随机抽样的视觉语言提示和随机选择任务的相应目标图像。在这些不同任务上的联合训练解锁了扩散模型的上下文学习能力,如图4和图5所示。应用classifier-free guidance (CFG)我们在训练过程中随机放弃10%的文本指导。

评估方法:

鉴于评估情境学习能力的定量指标有限,我们专注于提供大量的定性评估(即与其他方法的生成图比较)。

ControlNet的比较

我们定性地比较了Prompt DiffusionControlNet[65]我们遵循ControlNet的指导

,从相同的stable diffusion checkpointsControlNet进行微调,在我们的数据集上为每个反向任务独立地作为我们的基线,我们称之为CN(FT)。我们在图9中显示了比较结果。我们观察到联合训练的Prompt Diffusion与独立训练的ControlNet相比表现相当好,这意味着应用多任务学习没有显著的性能变化。我们在图10中进一步评估了特定于任务的ControlNet的生成能力。在推理过程中,我们直接将独立微调的ControlNet应用于一个新的任务。当直接将ControlNet应用于新的、不可见的任务时,我们可以观察到大量的失败。

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第7张图片

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第8张图片

模型体系结构

注意,我们锁定了Stable Diffusion编码器块的参数进行微调,以继承Stable Diffusion从大规模预训练中学习到的编码器能力。所有其他参数都可以进行微调。Stable Diffusion分支以潜伏空间中的图像作为输入,ControlNet分支以原始图像作为输入。将图像映射到latent space并将latent image重新映射到原始图像的预训练图像编码器和解码器未在插图中示出。latent image有四个通道,高度和宽度都缩小了八分之一。

正如在主要论文中提到的,我们的ControlNet分支将三个图像作为输入。其中两个形成示例对,首先在RGB通道中拼接,然后通过堆叠的卷积层进行编码。第三张图像表示图像查询,并使用独立的堆叠卷积层进行编码。我们将示例对和图像查询编码为相同维度的潜在嵌入,然后将它们相加作为ControlNet分支的输入。

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第9张图片

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第10张图片

More exemples:

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第11张图片

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第12张图片

[Prompt Diffusion]In-Context Learning Unlocked for Diffusion Models_第13张图片

你可能感兴趣的:(深度学习,人工智能)