©作者 | 刘星超
单位 | 德州大学奥斯汀分校
研究方向 | 生成式模型
Diffusion Generative Models(扩散式生成模型)已经在各种生成式建模任务中大放异彩,但是,其复杂的数学推导却常常让大家望而却步,缓慢的生成速度也极大地阻碍了研究的快速迭代和高效部署。研究过 DDPM 的同学可能见到过这种画风的变分法(Variational Inference)推导(截取自 What are Diffusion Models):
总体上推导的难度和对数学的要求还是比较高的。在连续时间的形式下,还需要随机微分方程(Stochastic Differential Equation(SDE))的知识,有不低的入门门槛。除此以外,扩散式生成模型的一个众所周知的老大难问题就是生成速度慢:生成一张图需要模拟一整个基于复杂的深度模型的扩散过程。缓慢的生成速度是阻碍这些模型更广泛的普及的一个主要瓶颈。
Rectified Flow,一个“简简单单走直线”生成模型,是我们对这些挑战的一个回答:极度简单,一步生成。我们的方法有以下要点:
(1)我们无需一般扩散模型复杂的推导,代之以一个简单的“沿直线生成”的思想。算法理解上不需要变分法或随机微分方程等基础知识。我们的方法是基于一个简单的常微分方程(ODE),通过构造一个“尽量走直线”的连续运动系统来产生想要的数据分布。
(2)“尽量走直线”的目的是让我们模型实现快速生成。通过一个叫“reflow”的方法,我们可以实现梦想中的“一步生成”:只需一步计算就直接产生高质量的结果,而不需要调用计算量大的数值求解器来迭代式地模拟整个扩散过程。
(3)通常的扩散模型是把高斯白噪声转换成想要的数据(比如图片)。我们的方法可以把任何一种数据或噪声(比如猫脸照片)转换成另外一种数据(比如人脸照片)。所以我们的方法不仅可以做生成模型,还可以应用于很多更广泛的迁移学习(比如 domain transfer)任务上。
有兴趣的同学可以参见我们的论文(Arxiv 或 OpenReview,以及和最优传输(optimal transport)相关的深入理论 Arxiv)。代码,示例 Colab Notebook 和预训练模型已经开源在 github。一个英文版简介在这里。欢迎大家使用和交流!
▲ Rectified Flow 可以实现生成式模型或者无监督图像转换(图中是人 ↔ 猫)。同时,通过新颖的 Reflow 算法,我们可以将 ODE 的轨迹拉直,在 N=1 时也取得较好的生成效果(图中 N 指我们所使用的 Euler 求解器的步数)。
我们先定义好要解决的问题。无论是从噪声生成图片(generative modeling),还是将人脸转化为猫脸(domain transfer),都可以这样概括成将一个分布转化成另一个分布的问题:
给定从两个分布 和 中的采样,我们希望找到一个传输映射 使得,当 时,。
比如,在生成模型里, 是高斯噪声分布, 是数据的分布(比如图片),我们想找到一个方法,把噪声 映射成一个服从 的数据 。在数据迁移(domain transfer)里, 分别是人脸和猫脸的图片。所以这个问题是生成模型和数据迁移的统一表述。
在我们的框架下,映射 是通过以下连续运动系统,也就是一个常微分方程(ordinary differential equation(ODE)),或者叫流模型(flow),来隐式定义的:
我们可以想象从 里采样出来的 是一个粒子。它从 时刻开始连续运动,在 时刻以 为速度。直到 时刻得到 。我们希望 服从分布 。这里我们假设 是一个神经网络。我们的任务是从数据里学习出 来达到 的目的。
因为 Rectified Flow 要在直线轨迹的交叉点做路径重组,所以上面的 ODE 模型(或者说 flow)的轨迹仍然可能是弯曲的(如上面的图(b)),不能达到一步生成。我们提出一个“Reflow”方法,将 ODE 的轨迹进一步变直。
具体的做法非常简单: 假设我们从 里采样出一批 。然后,从 出发,我们模拟上面学出的 flow(叫它 1-Rectified Flow),得到 。我们用这样得到的 对来学一个新的“2-Rectified Flow”:
这里,2-Rectified Flow 和 1-Rectified Flow 在训练过程中唯一的区别就是数据配对不同:在 1-Rectified Flow 中, 与 是随机或者任意配对的;在 2-Rectified Flow 中, 与 是通过 1-Rectified Flow 配对的。
上面的动图中,图(c)展示了 Reflow 的效果。因为从 1-Rectified Flow 里出来的 已经有很好的配对, 他们的直线插值交叉数减少,所以 2-Rectified Flow 的轨迹也就(比起 1-Rectified Flow)变得很直了(虽然仔细看还不完美)。
理论上,我们可以重复 Reflow 多次,从而得到 3-Rectified Flow, 4-Rectified Flow... 我们可以证明这个过程其实是在单调地减小最优传输理论中的传输代价(transport cost),而且最终收敛到完全直的状态。
当然,实际中,因为每次 优化得不完美,多次 Reflow 会积累误差,所以我们不建议做太多次的Reflow。幸运的是,在我们的实验中,我们发现对生成图片和很多我们感兴趣的问题而言,像上面的图(c)一样,1次 Reflow 已经可以得到非常直的轨迹了,配合蒸馏足够达到一步生成的效果了。
Reflow 解决了 Distillation 的这些困难。它的意义在于 :
1)给定任何 配对,就算是随机的配对,他都能学出一个给出正确边际分布(marginal distribution)的 flow。Reflow 不会去试图完全复现 的配对关系,而只注重于得到正确的边际分布。
2)从 Reflow 出的 ODE 里采样,我们还可以得到一个更好的配对 ,从而给出更好的 flow。重复这个过程可以最终得到保证一步生成的直线 ODE。
形象地来讲,如果 太复杂,Reflow会“拒绝”完全复现 ,转而给出一个新的,更简单的,但仍然满足 的配对 。所以,Distillation 更像“模仿者”,只会机械地模仿,就算问题无解也要“硬做”。Reflow 更像“创造者”,懂得变通,发现新方法来解决问题。
当然,Reflow 和 Distillation 也可以组合使用:先用 Reflow 得到比较好的配对,最后再用已经很好的配对进行 Distillation 。我们在论文里发现,这个结合的策略确实有用。
下面,我们进一步基于具体例子解释一下 Reflow 对配对的提高效果。如果一个配对 是好的,那么从这个配对里随机产生的两条直线 就不会相交。在我们的论文里,这种直线不相交的配对我们叫做“Straight Coupling”。我们的 Reflow 过程就是在不停地降低这个相交概率的过程。下图我们展示随着 Reflow 的不断进行,配对的直线交叉数确实逐渐降低。
在图中,对每种配对方法,我们随机选择两个配对,分别用直线段连接它们,然后若它们相交,就用红色点标出这两条直线段的交点。对于这种交叉的配对,Reflow 就有可能改善它们。
我们重复 10000 次并统计交叉的概率。我们发现:1)每次 Reflow 都降低了交叉的概率和 L2 传输代价;2)即使 2-Rectified Flow 在肉眼观察时已经很直,但它的交叉概率仍不为 0,更多的 Reflow 次数就可能进一步使它变直并降低传输代价。相比之下,单纯的蒸馏是不能改善配对的,这是 Reflow 与蒸馏的本质区别。
▲ 图中,每个红点代表一次两随机的直线交叉的事件。随着 reflow,交叉的概率逐渐降低,对应的 ODE 的轨迹也越来越直。
Rectified Flow 不仅简洁,而且在理论上也有很好的性质。我们在此给出一些理论保证的非正式表述,如果大家对理论部分感兴趣,欢迎大家阅读我们文章的细节。
1.边际分布不变:当 取得最优值时,对任意时间 ,我们有 和 的分布相等。因为 ,因此 确实可以将 转移到 。
2.降低传输损失:每次 Reflow 都可以降低两个分布之间的传输代价。特别的,Reflow 并不优化一个特定的损失函数,而是同时优化所有的凸损失函数。
3.拉直 ODE 轨迹:通过不停重复 Reflow,ODE 轨迹的直线性(Straightness)以 的速率下降,这里, 是 reflow 的次数。
使用 Runge Kutta-45 求解器,1-Rectified Flow 在 CIFAR10 上得到 IS=9.6, FID=2.58,recall=0.57,基本与之前的 VP SDE/sub-VP SDE [2] 相同,但是平均只需要 127 步进行模拟。
Reflow 可以使 ODE 轨迹变直,因此2-Rectified Flow 和 3-Rectified Flow 在仅用一步(N=1)时也可以有效的生成图片(FID=12.21/8.15)。
Reflow 可以降低传输损失,因此在进行蒸馏时会得到更好的表现。用 2-Rectified Flow + 蒸馏,我们在仅用一步生成时得到了 FID=4.85,远超之前最好的仅基于蒸馏/基于 GAN loss 的快速扩散式生成模型(当用一步采样时 FID=8.91)。同时,比起 GAN,Rectified Flow + 蒸馏有更好的多样性(recall>0.5)。
我们的方法也可以用于高清图片生成或无监督图像转换。
▲ 1-rectified flow: 256分辨率图像生成
▲ 1-rectified flow: 256分辨率无监督图像转换
同期相关工作
有意思的是,今年 ICLR 在 openreview 上出现了好几篇投稿论文提出了类似的想法。
(1) Flow Matching for Generative Modeling:
https://openreview.net/forum?id=PqvMRDCJT9t
(2) Building Normalizing Flows with Stochastic Interpolants:
https://openreview.net/forum?id=li7qeBbCR1t
(3) Iterative -alpha (de)Blending: Learning a Deterministic Mapping Between Arbitrary Densities:
https://openreview.net/forum?id=s7gnrEtWSm
(4) Action Matching: A Variational Method for Learning Stochastic Dynamics from Samples:
https://openreview.net/forum?id=T6HPzkhaKeS
这些工作都或多或少地提出了用拟合插值过程来构建生成式 ODE 模型的方法。除此之外,我们的工作还阐明了这个路径相交重组的直观解释和最优传输的内在联系,提出了 Reflow 算法,实现了一步生成,建立了比较完善的理论基础。大家不约而同地在一个地方发力,说明这个方法的出现是有很大的必然性的。因为它的简单形式和很好的效果,相信以后有很大的潜力。
如有任何问题,欢迎留言或者发邮件!
主要论文:
X. Liu, C. Gong, Q. Liu. Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow. ICLR 2023, arXiv:2209.03003
Q. Liu. Rectified flow: A marginal preserving approach to optimal transport. arXiv preprint arXiv:2209.14577, 2022.
参考文献
[1] Song Y, Sohl-Dickstein J, Kingma D P, et al. Score-Based Generative Modeling through Stochastic Differential Equations. International Conference on Learning Representations.
[2] Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems, 2020, 33: 6840-6851.
[3] Song J, Meng C, Ermon S. Denoising Diffusion Implicit Models. International Conference on Learning Representations.
[4] Lu C, Zhou Y, Bao F, et al. DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps. Advances in Neural Information Processing Systems.
[5] Bansal A, Borgnia E, Chu H M, et al. Cold diffusion: Inverting arbitrary image transforms without noise. arXiv preprint arXiv:2208.09392, 2022.
[6] Liu X, Wu L, Ye M. Learning Diffusion Bridges on Constrained Domains//International Conference on Learning Representations.
[7] Liu Q. Rectified flow: A marginal preserving approach to optimal transport. arXiv preprint arXiv:2209.14577, 2022.
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
投稿通道:
• 投稿邮箱:[email protected]
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
·
·