最近非常不巧的要研究什么diffusion…然而目前网上能找到的资料完全是设计给非常熟练数学的人看的(哪怕对于许多所谓的"入门教程",基本就是纯数学劝退教程),对于我这种高数概率论约等于挂科的人来说根本没法看。因此希望写一篇尽量通俗易懂,在尽量避免「概率论」的情况下,能把diffusion讲明白来的文章。
由于笔者数学并不是很好,且也只是刚刚接触diffusion模型,因此本文应「只」适合于同样「数学较差」无法看懂网络上其他地方(例如X乎)教程的同学,「不」适合对diffusion有关底层数学原理动机比较熟悉的。如果存在推理描述错误,以及对本文表述有疑问之类,欢迎一同在评论区中讨论。
© c s d n : x i o n g x y o w o \copyright csdn: xiongxyowo c◯csdn:xiongxyowo
如果后续推导中有不理解的数学定义,「回到」这里或许能找到解释。
y ∝ x y \propto x y∝x, y y y正比于 x x x,即 y y y随着 x x x增大而线性增大。
P ( A ∣ B ) P(A \mid B) P(A∣B)表示事件 B B B已经发生的情况下,事件 A A A发生的可能性。
换在本文的语境下,就是变量 B B B已知的情况下,变量 A A A的取值分布。
P ( A ∣ B , C ) P(A \mid B, C) P(A∣B,C)则是一种多元条件概率,表示在 B B B, C C C同时发生的情况下, A A A发生的概率。
换在本文的语境下,就是变量 B B B, C C C已知确定的情况下,变量 A A A的取值分布。
P ( A ∣ B ) = P ( B ∣ A ) ∗ P ( A ) P ( B ) P(A \mid B) = \frac{P(B \mid A) * P(A)}{P(B)} P(A∣B)=P(B)P(B∣A)∗P(A)
给定均值为 μ \mu μ,标准差为 σ \sigma σ,方差为 σ 2 \sigma^2 σ2的高斯分布 N ( μ , σ 2 ) \mathcal{N}(\mu, \sigma^2) N(μ,σ2),其概率密度函数为: p ( x ) = 1 2 π σ e − 1 2 ( x − μ σ ) 2 p(x) = \frac{1}{{\sqrt {2\pi } \sigma }}{e^{ - \frac{1}{2}{{(\frac{{x - \mu }}{\sigma })}^2}}} p(x)=2πσ1e−21(σx−μ)2。
很多时候,为了方便起见,也会写成 p ( x ) ∝ e − 1 2 ( x − μ σ ) 2 p(x) \propto {e^{ - \frac{1}{2}{{(\frac{{x - \mu }}{\sigma })}^2}}} p(x)∝e−21(σx−μ)2,也就是把前面乘的常数系数 1 2 π σ \frac{1}{{\sqrt {2\pi } \sigma }} 2πσ1去掉了。
进一步的,为了推导方便起见,我们把 exp ( − 1 2 ( x − μ σ ) 2 ) {\exp({ - \frac{1}{2}{{(\frac{{x - \mu }}{\sigma })}^2}})} exp(−21(σx−μ)2)展开,因此有 p ( x ) ∝ exp ( − 1 2 ( 1 σ 2 x 2 − 2 μ σ 2 x + μ 2 σ 2 ) ) p(x) \propto \exp(-\frac{1}{2}\left(\frac{1}{\sigma^2} x^2-\frac{2 \mu}{\sigma^2} x+\frac{\mu^2}{\sigma^2}\right)) p(x)∝exp(−21(σ21x2−σ22μx+σ2μ2))
如果对形如 q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)=\mathcal{N}\left(\mathbf{x}_t ; \sqrt{1-\beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}\right) q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)的式子感到疑惑,这篇文章提供了一种理解方法。
以上面提到的高斯分布为例,如果概率分布密度函数中的自变量不是默认的 x x x而是其他,那么应该在分布记号中显式的用分号表示实际的自变量。比如实际的自变量是 x 1 x_1 x1而非 x x x,那么高斯分布应记做 N ( x 1 ; μ , σ 2 ) \mathcal{N}(x_1; \mu, \sigma^2) N(x1;μ,σ2)。默认不写分号的话, N ( μ , σ 2 ) \mathcal{N}(\mu, \sigma^2) N(μ,σ2)等价于 N ( x ; μ , σ 2 ) \mathcal{N}(x; \mu, \sigma^2) N(x;μ,σ2)。
对(标准)高斯分布 N ( 0 , 1 ) \mathcal{N}(0, 1) N(0,1)做乘法,乘以 σ \sigma σ,得到一个新的高斯分布, N ( 0 , σ 2 ) \mathcal{N}(0, \sigma^2) N(0,σ2)。
对(标准)高斯分布 N ( 0 , 1 ) \mathcal{N}(0, 1) N(0,1)做加法,加上 μ \mu μ,得到一个新的高斯分布, N ( μ , 1 ) \mathcal{N}(\mu, 1) N(μ,1)。
两个高斯分布 N ( 0 , σ 1 2 ) \mathcal{N}(0, \sigma_1^2) N(0,σ12), N ( 0 , σ 2 2 ) \mathcal{N}(0, \sigma_2^2) N(0,σ22)相加,得到一个新的高斯分布, N ( 0 , σ 1 2 + σ 2 2 ) \mathcal{N}(0, \sigma_1^2 + \sigma_2^2) N(0,σ12+σ22)。
对高斯分布 N ( μ , σ 2 ) \mathcal{N}(\mu, \sigma^2) N(μ,σ2)进行采样一个噪声 ϵ \epsilon ϵ,等价于先从标准高斯分布 N ( 0 , 1 ) \mathcal{N}(0, 1) N(0,1)中采样得到一个噪声 z \mathbf{z} z,乘以均值 μ \mu μ,加上标准差 σ \sigma σ,即: ϵ = μ + z ⋅ σ \epsilon = \mu + \mathbf{z} \cdot \sigma ϵ=μ+z⋅σ。进行这一转化是为了方便网络训练。
什么是Denoising Diffusion Probabilistic Models,去噪扩散概率模型呢?
还是按照最小白的理解,diffusion其实也是一个深度网络(比方说Attention UNet),其输入是一个"噪声",而输出则是一种我们想要的有意义的数据(比如看起来十分真实的图像)。从这个角度看,diffusion很像另一种常见的生成模型,GAN。一般来讲,对于(非条件)GAN,同样是输入一个噪声,然后得到一个我们想要的东西。
那么diffusion模型是怎么工作的呢?对于GAN,我们知道会有个判别器,通过对抗训练的方式,让生成器逐渐学会将输入的噪声转化为有价值的信息。而diffusion的思想可以理解如下:
对于随机采样得到的一个标准高斯噪声,我们认为,其并不完全是噪声,里面其实还是有特别的信息的。比如说,我们随便找一张图像,往其中不断加入标准高斯噪声,最终图像会被噪声给淹没,但是原始的信息可以认为仍保留在最终得到的噪声中。那么现在,如果我们能设计一个神经网络,去学习「去噪」(denoising),完成上述加噪声过程的逆过程(去噪声);那么,对于任意采样得到的一个新标准高斯噪声,我们就可以通过去噪过程来恢复出其中原本存在的有价值信息(比方说看起来十分逼真的图像)。
正式一点的说法对应着下图:
diffusion模型包含两个过程:
接下来具体来看下前向过程和逆向过程。
前向过程也称为扩散过程,将真实数据逐步变成噪声。
比方说,给定一张原始图像 x 0 \mathbf{x}_0 x0,我们对其加一次「标准」高斯噪声 z ∼ N ( 0 , I ) \mathbf{z} \sim \mathcal{N}(0, \mathbf{I}) z∼N(0,I),得到 x 1 \mathbf{x}_1 x1。记 x i \mathbf{x}_i xi为对原始图像加 i i i次噪声后的结果,可以发现,当 i i i足够大的时候,数据会被高斯噪声淹没,变成纯正的高斯噪声。
现在就涉及到了第一个问题,加多少次噪声?在文中,其由一个超参数 T T T控制,即步数。原文 T = 1000 T=1000 T=1000,即对原始图像加1000次噪声后,其会变成完全的高斯噪声。
接下来是第二个问题,噪声怎么加?因为加噪过程本质是加权和,比如 0.8 × I m a g e + 0.1 × N o i s e 0.8×Image + 0.1×Noise 0.8×Image+0.1×Noise,会涉及到一个权重的问题(注意,我们后面会看到,图像的权重与噪声的权重相加并不需要为1)。在文章中,噪声的这个权重有个专有的名词,叫做扩散率,记为 β \beta β,比如可以从 0.0001 0.0001 0.0001逐步插值到 0.02 0.02 0.02。从这里可以看到,加噪是一个逐步的过程,对图像原有的信息是慢慢破坏的(扩散率很低)。这样主要是为了方便网络在逆扩散过程中学习去噪,如果对信息一次破坏太多那么网络可能就无法学会怎么去复原了。
而为什么扩散率是逐渐增大的呢?其实可以反过来理解,在加噪声的过程中,扩散率逐渐增大,对应着在去噪声的过程中,扩散率逐渐减小——也就是说,去噪的过程是先把"明显"的噪声给去除,对应着较大的扩散率;当去到一定程度,逐渐逼近真实真实图像的时候,去噪速率逐渐减慢,开始微调,也就是对应着较小的扩散率。
解决了这两个问题后,我们就可以来看扩散过程的初步数学定义了。给定当前具有一定噪声的图像 x t − 1 \mathbf{x}_{t-1} xt−1,加入标准高斯噪声噪声 z t − 1 ∼ N ( 0 , I ) \mathbf{z}_{t-1} \sim \mathcal{N}(0, \mathbf{I}) zt−1∼N(0,I),得到进一步加噪的图像 x t \mathbf{x}_t xt,有:
> 重要公式 1 < x t = 1 − β t x t − 1 + β t z t − 1 \mathbf{x}_t=\sqrt{1-\beta_t} \mathbf{x}_{t-1}+\sqrt{\beta_t} \mathbf{z}_{t-1} xt=1−βtxt−1+βtzt−1 > 重要公式 1 <
这个东西其实就是上面我们提到的 a × I m a g e + b × N o i s e a×Image + b×Noise a×Image+b×Noise,其中 I m a g e Image Image为 x t − 1 \mathbf{x}_{t - 1} xt−1, N o i s e Noise Noise为 z t − 1 \mathbf{z}_{t-1} zt−1。
其实有了上面这个式子,对于编程实现来说就已经足够了…不过大多数文章非常喜欢提下面这个式子,也就是概率分布的形式: q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)=\mathcal{N}\left(\mathbf{x}_t ; \sqrt{1-\beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}\right) q(xt∣xt−1)=N(xt;1−βtxt−1,βtI) 需要注意的是 x t = 1 − β t x t − 1 + β t z t − 1 \mathbf{x}_t=\sqrt{1-\beta_t} \mathbf{x}_{t-1}+\sqrt{\beta_t} \mathbf{z}_{t-1} xt=1−βtxt−1+βtzt−1和 q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)=\mathcal{N}\left(\mathbf{x}_t ; \sqrt{1-\beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}\right) q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)是等价的。具体来说,参考前置知识中的「重参数化技巧」,我们知道 ϵ = μ + z ⋅ σ \epsilon = \mu + \mathbf{z} \cdot \sigma ϵ=μ+z⋅σ表述的就是从 ϵ ∼ N ( μ , σ 2 ) \epsilon \sim \mathcal{N}(\mu, \sigma^2) ϵ∼N(μ,σ2)中采样的过程。据此,同样就可以将 x t \mathbf{x}_t xt改写为从 q ( x t ∣ x t − 1 ) q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) q(xt∣xt−1)中采样的形式。
当然,如果进行一些更人话的理解,则有:
现在来解决本章的最后一个问题。给定原始图像 x 0 \mathbf{x}_0 x0,能不能一步计算得到加噪任意 t t t次后的 x t \mathbf{x}_t xt?答案是可以的,这里首先直接给出结论:
> 重要公式 2 < x t = α ˉ t x 0 + 1 − α ˉ t z ~ t \mathbf{x}_{t} = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\tilde{\mathbf{z}}_t xt=αˉtx0+1−αˉtz~t > 重要公式 2 <
其中 α t = 1 − β t \alpha_t = 1 - \beta_t αt=1−βt, α t ˉ = α 1 × . . . × α t = ∏ i = 1 t α i \bar{\alpha_t} = \alpha_1 × ... × \alpha_t = \prod \limits_{i=1}^t \alpha_i αtˉ=α1×...×αt=i=1∏tαi, z ~ t ∼ N ( 0 , I ) \tilde{\mathbf{z}}_t \sim \mathcal{N}(0, \mathbf{I}) z~t∼N(0,I)。这样,当我们想求一个 t t t很大的 x t \mathbf{x_t} xt时,就省去了逐步模拟的麻烦。从这里可以发现,当 t t t很大时, α ˉ t \sqrt{\bar{\alpha}_t} αˉt会很接近 0 0 0,最终的结果 x t \mathbf{x}_t xt几乎完全由噪声 z ~ t \tilde{\mathbf{z}}_t z~t所取代,但仍然保留了十分微弱的原始图像 x 0 \mathbf{x}_0 x0。也就是说,只要方法巧妙,理论上还是可以通过逐步去噪来把 x t \mathbf{x}_t xt中隐藏的 x 0 \mathbf{x}_0 x0给搞到手的。
关于 x t = α ˉ t x 0 + 1 − α ˉ t z ~ t \mathbf{x}_{t} = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\tilde{\mathbf{z}}_t xt=αˉtx0+1−αˉtz~t的推导,其实只要老老实实从 x t \mathbf{x}_{t} xt变到 x t − 1 \mathbf{x}_{t-1} xt−1, x t − 2 \mathbf{x}_{t-2} xt−2…一直反复代入进去即可,相对较为简单。对推导不感兴趣的以下部分可以跳过,不影响对其他部分的理解。
> 推导开始 <
x t = α t x t − 1 + 1 − α t z t − 1 = α t α t − 1 x t − 2 + α t ( 1 − α t − 1 ) z t − 2 + 1 − α t z t − 1 = α t α t − 1 x t − 2 + 1 − α t α t − 1 z ˉ t − 2 = … = α ˉ t x 0 + 1 − α ˉ t z ~ t \begin{aligned} \mathbf{x}_t &=\sqrt{\alpha_t} \mathbf{x}_{t-1}+\sqrt{1-\alpha_t} \mathbf{z}_{t-1} \\ &=\sqrt{\alpha_t \alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{\alpha_t\left(1-\alpha_{t-1}\right)} \mathbf{z}_{t-2}+\sqrt{1-\alpha_t} \mathbf{z}_{t-1} \\ &=\sqrt{\alpha_t \alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}} \bar{\mathbf{z}}_{t-2} \\ &=\ldots \\ &=\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \tilde{\mathbf{z}}_t \end{aligned} xt=αtxt−1+1−αtzt−1=αtαt−1xt−2+αt(1−αt−1)zt−2+1−αtzt−1=αtαt−1xt−2+1−αtαt−1zˉt−2=…=αˉtx0+1−αˉtz~t 上面式子中各种 z \mathbf{z} z的变体都满足标准高斯分布 N ( 0 , 1 ) \mathcal{N}(0, 1) N(0,1)。第一行到第二行就是把 x t − 1 = α t − 1 x t − 2 + 1 − α t − 1 z t − 2 \mathbf{x}_{t-1} =\sqrt{\alpha_{t-1}} \mathbf{x}_{t-2}+\sqrt{1-\alpha_{t-1}} \mathbf{z}_{t-2} xt−1=αt−1xt−2+1−αt−1zt−2给换进去。主要比较难理解的地方在于第二行到第三行, α t ( 1 − α t − 1 ) z t − 2 + 1 − α t z t − 1 \sqrt{\alpha_t\left(1-\alpha_{t-1}\right)} \mathbf{z}_{t-2}+\sqrt{1-\alpha_t} \mathbf{z}_{t-1} αt(1−αt−1)zt−2+1−αtzt−1是怎么变成 1 − α t α t − 1 z ˉ t − 2 \sqrt{1-\alpha_t \alpha_{t-1}}\bar{\mathbf{z}}_{t-2} 1−αtαt−1zˉt−2的。
具体来说, α t ( 1 − α t − 1 ) z t − 2 \sqrt{\alpha_t\left(1-\alpha_{t-1}\right)} \mathbf{z}_{t-2} αt(1−αt−1)zt−2其实就是 N ( 0 , α t ( 1 − α t − 1 ) ) \mathcal{N}(0, \alpha_t(1-\alpha_{t-1})) N(0,αt(1−αt−1)), 1 − α t z t − 1 \sqrt{1-\alpha_t} \mathbf{z}_{t-1} 1−αtzt−1其实就是 N ( 0 , 1 − α t ) \mathcal{N}(0, 1-\alpha_t) N(0,1−αt),两者相加,得到 N ( 0 , 1 − α t α t − 1 ) \mathcal{N}(0, 1-\alpha_{t}\alpha_{t-1}) N(0,1−αtαt−1),也就是 1 − α t α t − 1 z ˉ t − 2 \sqrt{1-\alpha_t \alpha_{t-1}} \bar{\mathbf{z}}_{t-2} 1−αtαt−1zˉt−2。这里不同形式的 z \mathbf{z} z单纯是起「区分」作用,本质上同属于一个分布 N ( 0 , I ) \mathcal{N}(0, \mathbf{I}) N(0,I)下的「不同」采样。
> 推导结束 <
> 本章小结 <
总结一下,扩散过程就是给定原始图像 x 0 \mathbf{x}_0 x0,获取其加入不同次噪声后的结果 x t \mathbf{x}_t xt的过程。这些 x t \mathbf{x}_t xt将作为标签,帮助网络学会如何从纯噪声 x T \mathbf{x}_T xT中一步一步去噪,最终恢复出真实图像 x 0 \mathbf{x}_0 x0。
扩散过程是从原始数据 x 0 \mathbf{x}_0 x0逐渐加噪声变成 x T \mathbf{x}_T xT。所谓逆扩散过程,也就是从 x T \mathbf{x}_T xT逐步给回到 x 0 \mathbf{x}_0 x0,即求: q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_t}) q(xt−1∣xt) 也就是说,现在我们知道了带噪声的数据 x t \mathbf{x}_t xt,想要知道其去掉一次噪声后的 x t − 1 \mathbf{x}_{t-1} xt−1是什么样的。去得噪声足够多,最后没有噪声,自然就回到了我们的原始数据 x 0 \mathbf{x}_0 x0。
那 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_t}) q(xt−1∣xt)怎么求呢?可以发现,加噪过程 q ( x t ∣ x t − 1 ) q(\mathbf{x}_{t} \mid {\mathbf{x}_{t-1}}) q(xt∣xt−1)我们是知道的,因此利用贝叶斯公式的思想,有: q ( x t − 1 ∣ x t ) = q ( x t ∣ x t − 1 ) × q ( x t − 1 ) q ( x t ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}}) = \frac{q(\mathbf{x}_{t} \mid {\mathbf{x}_{t-1}})×q(\mathbf{x}_{t-1})}{q({\mathbf{x}_t})} q(xt−1∣xt)=q(xt)q(xt∣xt−1)×q(xt−1) 现在就出现了一个问题,虽然 q ( x t ∣ x t − 1 ) q(\mathbf{x}_{t} \mid {\mathbf{x}_{t-1}}) q(xt∣xt−1)我们是知道了,但是 q ( x t ) q(\mathbf{x}_{t}) q(xt)和 q ( x t − 1 ) q(\mathbf{x}_{t-1}) q(xt−1)我们不知道。这里需要特别注意的是,当 T T T足够大的时候,可以认为 q ( x T ) q(\mathbf{x}_T) q(xT)就是标准高斯噪声,这个我们是可以知道的;而由于 t t t我们并不知道是多少,可能是个很小的值,这种情况下 q ( x t ) q(\mathbf{x}_t) q(xt)中包含了大量的原始图像信息,因此 q ( x t ) q(\mathbf{x}_t) q(xt)我们是不知道的。
要想知道加了一定噪声的图像 q ( x t ) q(\mathbf{x}_t) q(xt)和 q ( x t − 1 ) q(\mathbf{x}_{t-1}) q(xt−1),自然就依赖于一个先决条件,没加噪声的图像 q ( x 0 ) q(\mathbf{x_0}) q(x0)。换句话说, q ( x t ∣ x 0 ) q(\mathbf{x}_t \mid \mathbf{x_0}) q(xt∣x0)和 q ( x t − 1 ∣ x 0 ) q(\mathbf{x}_{t-1} \mid \mathbf{x_0}) q(xt−1∣x0)我们是知道的,因此对式子 q ( x t − 1 ∣ x t ) = q ( x t ∣ x t − 1 ) × q ( x t − 1 ) q ( x t ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}}) = \frac{q(\mathbf{x}_{t} \mid {\mathbf{x}_{t-1}})×q(\mathbf{x}_{t-1})}{q({\mathbf{x}_t})} q(xt−1∣xt)=q(xt)q(xt∣xt−1)×q(xt−1)再加上一个条件 x 0 \mathbf{x_0} x0,得到一个多元条件分布,有: q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) × q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}, \mathbf{x}_0}) = \frac{q(\mathbf{x}_{t} \mid {\mathbf{x}_{t-1} , \mathbf{x}_0})×q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}{q({\mathbf{x}_t} \mid \mathbf{x}_0)} q(xt−1∣xt,x0)=q(xt∣x0)q(xt∣xt−1,x0)×q(xt−1∣x0) 其实上面这个式子还可以继续变一下。由于扩散过程是一个马尔可夫过程,因此 x t \mathbf{x}_t xt只和 x t − 1 \mathbf{x}_{t-1} xt−1有关,和 x 0 \mathbf{x}_0 x0无关,即 q ( x t ∣ x t − 1 , x 0 ) = q ( x t ∣ x t − 1 ) q(\mathbf{x}_{t} \mid {\mathbf{x}_{t-1} , \mathbf{x}_0}) = q(\mathbf{x}_{t} \mid {\mathbf{x}_{t-1}}) q(xt∣xt−1,x0)=q(xt∣xt−1),有: q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 ) × q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}, \mathbf{x}_0}) = \frac{q(\mathbf{x}_{t} \mid {\mathbf{x}_{t-1} })×q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}{q({\mathbf{x}_t} \mid \mathbf{x}_0)} q(xt−1∣xt,x0)=q(xt∣x0)q(xt∣xt−1)×q(xt−1∣x0) 其实细心的读者可以发现一个问题,在测试阶段, x 0 \mathbf{x}_0 x0本身是我们要求的东西,是未知的;因此上面这个式子只有在训练阶段 x 0 \mathbf{x}_0 x0已知的情况下才能运行起来。为了让测试阶段也能用,我们对上面这个式子进行进一步的分析,「看看能不能把 x 0 \mathbf{x}_0 x0给消除掉」。如果能消除,就不用陷入这种要算 x 0 \mathbf{x}_0 x0必须知道 x 0 \mathbf{x}_0 x0的套娃情况了。根据上一章的重要公式1,2:
这里为什么要把概率密度函数的形式给拿出来呢?其实是方便运算。这里先给出一个简单的结论,两个分布相乘,可以认为就是对其密度函数相加;两个分布相除,可以认为就是对其密度函数相减。因此, q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 ) × q ( x t − 1 ∣ x 0 ) / q ( x t ∣ x 0 ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}, \mathbf{x}_0}) = q(\mathbf{x}_{t} \mid {\mathbf{x}_{t-1} })×q(\mathbf{x}_{t-1} \mid \mathbf{x}_0) / q({\mathbf{x}_t} \mid \mathbf{x}_0) q(xt−1∣xt,x0)=q(xt∣xt−1)×q(xt−1∣x0)/q(xt∣x0),写成密度函数的形式,有: q ( x t − 1 ∣ x t , x 0 ) ∝ exp ( − 1 2 [ ( x t − α t x t − 1 ) 2 β t + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ] ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}, \mathbf{x}_0}) \propto \exp(-\frac{1}{2} [\frac{(\mathbf{x}_t - \sqrt{{\alpha}_{t}}\mathbf{x}_{t-1})^2}{{\beta}_{t}} + \frac{(\mathbf{x}_{t-1} - \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0)^2}{1 - \bar{\alpha}_{t-1}} - \frac{(\mathbf{x}_{t} - \sqrt{\bar{\alpha}_t}\mathbf{x}_0)^2}{1 - \bar{\alpha}_t}]) q(xt−1∣xt,x0)∝exp(−21[βt(xt−αtxt−1)2+1−αˉt−1(xt−1−αˉt−1x0)2−1−αˉt(xt−αˉtx0)2]) 现在,我们要对上面这个式子进行进一步的整理,看看能不能搞出什么有用的东西来。那么就先把括号里的平方展开来试一试: q ( x t − 1 ∣ x t , x 0 ) ∝ exp ( − 1 2 [ x t 2 − 2 α t x t x t − 1 + α t x t − 1 2 β t + x t − 1 2 − 2 α ˉ t − 1 x 0 x t − 1 + α ˉ t − 1 x 0 2 1 − α ˉ t − 1 − x t 2 − 2 α ˉ t x 0 x t + α ˉ t x 0 2 1 − α ˉ t ] ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}, \mathbf{x}_0}) \propto \exp(-\frac{1}{2} [\frac{\mathbf{x}_t^2 - 2\sqrt{{\alpha}_{t}}\mathbf{x}_t\mathbf{x}_{t-1} +{{\alpha}_{t}}\mathbf{x}_{t-1}^2}{{\beta}_{t}} + \frac{\mathbf{x}_{t-1}^2 - 2\sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0\mathbf{x}_{t-1} + \bar{\alpha}_{t-1}\mathbf{x}_0^2}{1 - \bar{\alpha}_{t-1}} - \frac{\mathbf{x}_{t}^2 -2\sqrt{\bar{\alpha}_t}\mathbf{x}_0\mathbf{x}_{t} + \bar{\alpha}_t\mathbf{x}_0^2}{1 - \bar{\alpha}_t}]) q(xt−1∣xt,x0)∝exp(−21[βtxt2−2αtxtxt−1+αtxt−12+1−αˉt−1xt−12−2αˉt−1x0xt−1+αˉt−1x02−1−αˉtxt2−2αˉtx0xt+αˉtx02]) 接下来的操作就是比较有技巧性的了。回到最初的问题,我们这一通化简,都是为了求于 x t − 1 \mathbf{x}_{t-1} xt−1有关的条件分布 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}, \mathbf{x}_0}) q(xt−1∣xt,x0)。基于这一直觉,我们把上式的 x t − 1 \mathbf{x}_{t-1} xt−1给提取整理出来,有: q ( x t − 1 ∣ x t , x 0 ) ∝ exp ( − 1 2 [ ( α t β t + 1 1 − α ˉ t − 1 ) x t − 1 2 − ( 2 α t β t x t + 2 a ˉ t − 1 1 − α ˉ t − 1 x 0 ) x t − 1 + C ( x t , x 0 ) ] ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}, \mathbf{x}_0}) \propto \exp (-\frac{1}{2}[(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}) \mathbf{x}_{t-1}^2-(\frac{2 \sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{2 \sqrt{\bar{a}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0) \mathbf{x}_{t-1}+C(\mathbf{x}_t, \mathbf{x}_0)]) q(xt−1∣xt,x0)∝exp(−21[(βtαt+1−αˉt−11)xt−12−(βt2αtxt+1−αˉt−12aˉt−1x0)xt−1+C(xt,x0)]) 注意,上面这个式子中, C ( x t , x 0 ) C(\mathbf{x}_t, \mathbf{x}_0) C(xt,x0)其实就是 x t 2 − 2 α ˉ t x 0 x t + α ˉ t x 0 2 1 − α ˉ t \frac{\mathbf{x}_{t}^2 -2\sqrt{\bar{\alpha}_t}\mathbf{x}_0\mathbf{x}_{t} + \bar{\alpha}_t\mathbf{x}_0^2}{1 - \bar{\alpha}_t} 1−αˉtxt2−2αˉtx0xt+αˉtx02,即 q ( x t ∣ x 0 ) q({\mathbf{x}_t} \mid \mathbf{x}_0) q(xt∣x0)。因为上面这步化简的目的是将概率密度函数视为以 x t − 1 \mathbf{x}_{t-1} xt−1为自变量的函数,而 x t 2 − 2 α ˉ t x 0 x t + α ˉ t x 0 2 1 − α ˉ t \frac{\mathbf{x}_{t}^2 -2\sqrt{\bar{\alpha}_t}\mathbf{x}_0\mathbf{x}_{t} + \bar{\alpha}_t\mathbf{x}_0^2}{1 - \bar{\alpha}_t} 1−αˉtxt2−2αˉtx0xt+αˉtx02里面不包含 x t − 1 \mathbf{x}_{t-1} xt−1,所以就将其视为常量 C C C了。
那么上面这个整理的式子究竟有什么用呢?回顾下,以 x x x为自变量的高斯分布 N ( x ; μ , σ 2 ) \mathcal{N}(x; \mu, \sigma^2) N(x;μ,σ2),其概率密度函数正比于 exp ( − 1 2 ( 1 σ 2 x 2 − 2 μ σ 2 x + μ 2 σ 2 ) ) \exp(-\frac{1}{2}\left(\frac{1}{\sigma^2} x^2-\frac{2 \mu}{\sigma^2} x+\frac{\mu^2}{\sigma^2}\right)) exp(−21(σ21x2−σ22μx+σ2μ2))。可以发现,上面式子中 x t − 1 2 \mathbf{x}_{t-1}^2 xt−12与 x t − 1 \mathbf{x}_{t-1} xt−1的系数,其中就包含了 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}, \mathbf{x}_0}) q(xt−1∣xt,x0)这个「高斯分布」中均值与方差的信息。注意,逆向过程 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}, \mathbf{x}_0}) q(xt−1∣xt,x0)与前向过程一样,同样是一种「高斯分布」,但是对其进行证明不在本文的讨论之内,这里直接当做结论来使用。
现在,我们就尝试将 x t − 1 \mathbf{x}_{t-1} xt−1的均值和方差给求出来。根据 N ( x ; μ , σ 2 ) ∝ exp ( − 1 2 ( 1 σ 2 x 2 − 2 μ σ 2 x + μ 2 σ 2 ) ) \mathcal{N}(x; \mu, \sigma^2) \propto \exp(-\frac{1}{2}\left(\frac{1}{\sigma^2} x^2-\frac{2 \mu}{\sigma^2} x+\frac{\mu^2}{\sigma^2}\right)) N(x;μ,σ2)∝exp(−21(σ21x2−σ22μx+σ2μ2)),我们发现,方差 σ 2 \sigma^2 σ2就是 x 2 x^2 x2系数的倒数;而 x t − 1 2 \mathbf{x}^2_{t-1} xt−12的系数为 α t β t + 1 1 − α ˉ t − 1 \frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}} βtαt+1−αˉt−11,可以发现,完全只由人工确定的超参数 α \alpha α和 β \beta β所确定,因此方差是已知的。而对于均值,其值与 x t − 1 \mathbf{x}_{t-1} xt−1的系数 2 α t β t x t + 2 a ˉ t − 1 1 − α ˉ t − 1 x 0 \frac{2 \sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{2 \sqrt{\bar{a}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0 βt2αtxt+1−αˉt−12aˉt−1x0有关。可以发现,除了已知量 α \alpha α, β \beta β, x t \mathbf{x}_t xt,依然包含着我们想要消除的项 x 0 \mathbf{x}_0 x0。
现在,我们将均值 μ \mu μ写成一个关于 x t \mathbf{x}_t xt与 x 0 \mathbf{x}_0 x0的函数,记做 μ ~ t ( x t , x 0 ) \tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right) μ~t(xt,x0)。通过代入 σ 2 \sigma^2 σ2求解 2 μ σ 2 = 2 α t β t x t + 2 a ˉ t − 1 1 − α ˉ t − 1 x 0 \frac{2 \mu}{\sigma^2} = \frac{2 \sqrt{\alpha_t}}{\beta_t} \mathbf{x}_t+\frac{2 \sqrt{\bar{a}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_0 σ22μ=βt2αtxt+1−αˉt−12aˉt−1x0,我们可以得到: μ ~ t ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)=\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0 μ~t(xt,x0)=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βtx0 做到这一步,我们已经把求解 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}, \mathbf{x}_0}) q(xt−1∣xt,x0)这一复杂的问题,转化为怎么去求解该分布的均值 μ \mu μ的问题。而要求 μ \mu μ的话,就得想办法把复杂的 x 0 \mathbf{x}_0 x0给消掉或简化,有没有办法把 x 0 \mathbf{x}_0 x0化简成一个更容易看懂的形式呢?
答案是有的。可以发现重要公式2里面有 x 0 \mathbf{x}_0 x0: x t = α ˉ t x 0 + 1 − α ˉ t z ~ t \mathbf{x}_{t} = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\tilde{\mathbf{z}}_t xt=αˉtx0+1−αˉtz~t 我们直接把 x 0 \mathbf{x}_0 x0给移到等式左边来… x 0 = 1 α ˉ t ( x t − 1 − α ˉ t z ~ t ) \mathbf{x}_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t} \tilde{\mathbf{z}}_t\right) x0=αˉt1(xt−1−αˉtz~t) 然后把 x 0 \mathbf{x}_0 x0给代回去…
> 重要公式3 < μ ~ t ( x t ) = 1 α t ( x t − 1 − α t 1 − α ˉ t z ~ t ) \boldsymbol{\tilde{\mu}}_t(\mathbf{x}_t)=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1 - \alpha_t}{\sqrt{1-\bar{\alpha}_t}} \tilde{\mathbf{z}}_t\right) μ~t(xt)=αt1(xt−1−αˉt1−αtz~t) > 重要公式3 <
这样就把 x 0 \mathbf{x}_0 x0给消掉了。也就是说,只要知道了 z ~ t \tilde{\mathbf{z}}_t z~t,我们就可以把 μ ~ t \boldsymbol{\tilde{\mu}}_t μ~t给算出来,进而得到 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_{t}}) q(xt−1∣xt),采样出 x t − 1 \mathbf{x}_{t-1} xt−1,完成去噪的过程。
但问题是, z ~ t \tilde{\mathbf{z}}_t z~t本身也是训练阶段,加噪过程中涉及到的东西。在测试阶段,对于一个全新采样的噪声,我们并不知道其是由一张图像与具体哪个高斯噪声给合成出来的(采样有无数种可能)。而且,从数学推导的角度, z ~ t \tilde{\mathbf{z}}_t z~t作为一个噪声,已经非常原子了,没法将其转换成更易获得的形式。
至此,就该请出深度学习了,神经网络最擅长的就是这种人解不出但是可以通过算法去逼近的东西。也就是说,要设计一个网络 ϵ ( x t , t ) \boldsymbol{\epsilon}(\mathbf{x}_t, t) ϵ(xt,t),我们希望其能够预测 z ~ t \tilde{\mathbf{z}}_t z~t。
> 本章小结 <
已知当前图像 x t \mathbf{x}_t xt,获得去噪一步后的图像 x t − 1 \mathbf{x}_{t-1} xt−1的过程,用概率的形式写作 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_t}) q(xt−1∣xt)。用贝叶斯公式对其处理后,我们发现,必须在知道 x 0 \mathbf{x}_0 x0的情况下才能求解 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_t}) q(xt−1∣xt),而 x 0 \mathbf{x}_0 x0本身是去噪的最终目的,因此看起来构成了死循环。所以,我们尝试将 x 0 \mathbf{x}_0 x0进行变形消除,最后发现只要能够求到一个噪声 z ~ t \tilde{\mathbf{z}}_t z~t,就能够对 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \mid {\mathbf{x}_t}) q(xt−1∣xt)进行模拟,完成逆扩散过程。
其实最后网络要预测的是一个噪声,这一结论也非常符合直觉。因为 x t \mathbf{x}_{t} xt本身是加噪声得到的,那么我们如果知道加的噪声是啥,自然能把这一过程反过来。
对应着原文这么一张图:
对于每次迭代:
2 => 随机选择一张图像。从数学的角度讲,叫做从真实图像分布 q ( x 0 ) q(\mathbf{x_0}) q(x0)中采样得到一个样本 x 0 \mathbf{x_0} x0。
3 => 随机选择一个前向步数(加噪声次数) t t t。这个 t t t是从最小步数 1 1 1和最大步数 T T T中随机抽出来的。从数学的角度讲,叫从均匀分布 1 ∼ T 1 \sim T 1∼T中采样。
4 => 随机生成一个标准高斯噪声 ϵ \epsilon ϵ。从数学的角度讲,叫做从标准高斯分布 N ( 0 , I ) \mathcal{N}(0, \mathbf{I}) N(0,I)中采样。
5 => 计算训练时损失(也就是"进行梯度下降步骤")。而 ∣ ∣ a − b ∣ ∣ 2 ||a - b||^2 ∣∣a−b∣∣2其实就是最常见的均方误差损失函数(Mean Square Loss)。
既然是损失函数,肯定就有一个真值和一个网络的预测值。这里的真值就是实时生成的随机噪声 ϵ \epsilon ϵ,而网络预测值则是这么坨东西: ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) \boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}, t\right) ϵθ(αˉtx0+1−αˉtϵ,t) α ˉ t x 0 + 1 − α ˉ t ϵ \sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon} αˉtx0+1−αˉtϵ是什么含义呢?回忆重要公式2: x t = α ˉ t x 0 + 1 − α ˉ t z ~ t \mathbf{x_t} = \sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \tilde{\mathbf{z}}_t xt=αˉtx0+1−αˉtz~t 而这个 z ~ t \tilde{\mathbf{z}}_t z~t和 ϵ \epsilon ϵ同样都是标准高斯噪声。也就是说, α ˉ t x 0 + 1 − α ˉ t ϵ \sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon} αˉtx0+1−αˉtϵ其实就是 x t \mathbf{x_t} xt。至此,损失函数变成了个这样的东西: ∣ ∣ ϵ − ϵ θ ( x t , t ) ∣ ∣ 2 ||\epsilon - \boldsymbol{\epsilon}_\theta\left(\mathbf{x_t}, t\right)||^2 ∣∣ϵ−ϵθ(xt,t)∣∣2 翻译成人话就是,对于2,3步拿到的原始图像 x 0 \mathbf{x_0} x0和加噪次数 t t t,利用前向过程能够直接推出加噪结果 x t \mathbf{x_t} xt出来。现在有一个网络 ϵ θ \epsilon_\theta ϵθ,我们希望其在输入加噪结果 x t \mathbf{x_t} xt和加噪次数 t t t后,能够预测到一个「合适的」标准高斯噪声,也就是我们在重要公式3中所未知的 z ~ t \tilde{\mathbf{z}}_t z~t。
> 本章小结 <
训练阶段的动机其实是比较难理解的。这里给出我个人的一种解读,可能有误。
一个很难想明白的地方在于,网络为什么要去预测一个标准高斯噪声?直观来讲,这种东西我们直接从标准高斯分布中直接采样就可以了,为什么还要单独设计一个网络去学。要想理解这一点,我们将损失函数的表达式重新展开来,把 ϵ \epsilon ϵ替换成我们熟悉的 z ~ t \tilde{\mathbf{z}}_t z~t;此外,由于训练阶段的 x t \mathbf{x}_t xt是由 x 0 \mathbf{x}_0 x0和 t t t直接求出来的,因此我们也进行相应的替换,最终我们可以把: ∣ ∣ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∣ ∣ 2 ||\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}, t\right)||^2 ∣∣ϵ−ϵθ(αˉtx0+1−αˉtϵ,t)∣∣2
重新改写为: ∣ ∣ z ~ t − ϵ θ ( x 0 , t , z ~ t ) ∣ ∣ 2 ||\tilde{\mathbf{z}}_t - \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_0, t, \tilde{\mathbf{z}}_t\right)||^2 ∣∣z~t−ϵθ(x0,t,z~t)∣∣2 这么写有什么好处呢?我们可以发现一个有趣的事实,在训练阶段,网络去猜测这个 z ~ t \tilde{\mathbf{z}}_t z~t并不是凭空的,而是事实上已经将 z ~ t \tilde{\mathbf{z}}_t z~t和 x 0 \mathbf{x}_0 x0给混在了一起,得到了一个混沌,然后让网络去从混沌中把 z ~ t \tilde{\mathbf{z}}_t z~t给重新"捞出来"。
举个例子就是,与其说 ϵ θ \boldsymbol{\epsilon}_\theta ϵθ是一个所谓的什么去噪网络,不如说是「沙里淘金」:
图像是金,噪声是沙子。在训练阶段,我们把金和沙子混在一起(加噪),让网络学习怎么去把沙子从混合物中给重新分离出来(预测噪声)。至于为什么不是直接把金给拿出来…这是上一章的推导决定的,求噪声要比求图像远远更容易;换句话说,如果是直接淘金,那么网络可能淘个成百上千次,准确率仍然是0,因此很难训练,所以才是淘沙。
从这里发现,网络学到的如何淘沙子的知识,是来源于沙子和金的混合物的,受原有的金(图像)的影响。这就导致,网络在猫图像上训练的去噪网络,对于一个新噪声而言去噪也只能得到各种猫,因为在训练阶段真实分布的信息被嵌入了网络中。
而在测试阶段,相当于仍是有一堆混在一起的金和沙子,这个时候没有标准答案,网络是凭借着自己的训练阶段学到的知识把沙子给淘出来,进而「间接」完成淘金的过程。
对应着原文这么一张图:
1 => 从标准高斯分布中采样得到一个噪声。由于原始图像 x 0 \mathbf{x}_0 x0在加 t t t次噪声后得到的东西也是一个标准高斯噪声,因此这里采样的得到的我们将其记为 x T \mathbf{x}_T xT。
2 => 进行 T T T次逆扩散过程,将图像从高斯噪声 x T \mathbf{x}_T xT中恢复出来。对于每次逆扩散过程:
3 => 随机采样一个标准高斯噪声 z \mathbf{z} z。注意在最后一步的时候我们就不采样了, z = 0 \mathbf{z} = 0 z=0,这算是一个trick…不管这一技巧并不影响对整体的理解。
4 => 通过公式计算得到去噪一次的结果,也就是这么个东西: x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z \mathbf{x}_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right)+\sigma_t \mathbf{z} xt−1=αt1(xt−1−αˉt1−αtϵθ(xt,t))+σtz 这个式子的理解依旧是参考前置知识中的「重参数化技巧」。从分布的角度,比方说,从 N ( μ , σ 2 ) \mathcal{N}(\mu, \sigma^2) N(μ,σ2)中采样得到一个 ϵ \epsilon ϵ,写成数学表达式就是: ϵ = μ + z ⋅ σ \epsilon= \mu + \mathbf{z} \cdot \sigma ϵ=μ+z⋅σ 其中 z \mathbf{z} z为标准高斯噪声。根据重要公式3,我们知道高斯分布 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \mid \mathbf{x}_t) q(xt−1∣xt)的均值 μ t \mu_{t} μt为 1 α t ( x t − 1 − α t 1 − α ˉ t z ~ t ) \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1 - \alpha_t}{\sqrt{1-\bar{\alpha}_t}} \tilde{\mathbf{z}}_t\right) αt1(xt−1−αˉt1−αtz~t),再加上方差 σ t 2 \sigma_t^2 σt2(可以由超参数 α \alpha α和 β \beta β直接求得),有:
q ( x t − 1 ∣ x t ) ∼ N ( 1 α t ( x t − 1 − α t 1 − α ˉ t z ~ t , σ t 2 ) q(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}) \sim \mathcal{N}(\frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \tilde{\mathbf{z}}_t, \sigma_t^2) q(xt−1∣xt)∼N(αt1(xt−1−αˉt1−αtz~t,σt2) 而 z ~ t \tilde{\mathbf{z}}_t z~t的话,其实就是网络 ϵ θ ( x t , t ) \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) ϵθ(xt,t)能够预测的东西,直接替换掉就行。从分布 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}) q(xt−1∣xt)中采样得到 x t − 1 \mathbf{x}_{t-1} xt−1,有: x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z \mathbf{x}_{t-1}=\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right)+\sigma_t \mathbf{z} xt−1=αt1(xt−1−αˉt1−αtϵθ(xt,t))+σtz
> 本章小结 <
训练阶段得到的网络能够对 z ~ t \tilde{\mathbf{z}}_t z~t进行预测,从而使得我们能在知道 x t \mathbf{x}_{t} xt的情况下,从分布 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}) q(xt−1∣xt)中采样得到 x t − 1 \mathbf{x}_{t-1} xt−1,逐步完成去噪过程。
个人对diffusion的思想路线可以概括如下:
由于本文为了方便理解起见,去除了大量的推导证明过程,有不足之处还请指正。
❤️ 如果觉得这篇文章对你理解diffusion有所帮助,欢迎在下方点个「赞」和「收藏」。
这里给出了一些对diffusion更加严谨细节推导的热门解读:
https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
https://zhuanlan.zhihu.com/p/525106459
https://wrong.wang/blog/20220605-%E4%BB%80%E4%B9%88%E6%98%AFdiffusion%E6%A8%A1%E5%9E%8B/
https://www.bilibili.com/video/BV1b541197HX