观看本文之前建议先观看以下文章:
在推导过程中会参考其中中的一些公式,使用到的公式都会标注出来。
Diffusion Models(扩散模型)包含以下三类:
本文以2020年Ho等人的DDPM为例,其包含了前向扩散过程和反向的扩散过程。
其中,前向扩散过程是为了将复杂的分布转化为一个简单的分布。而反向扩散过程则是从简单分布逆转得到复杂分布。
扩散(Diffusion)在热力学中指细小颗粒从高密度区域扩散至低密度区域,在统计领域,扩散则指将复杂的分布转换为一个简单的分布的过程。
Diffusion模型定义了一个概率分布转换模型 T \mathcal{T} T,能将原始数据 x 0 x_0 x0构成的复杂分布 q c o m p l e x q_{\mathrm{complex}} qcomplex转换为一个简单的已知参数的先验分布 p p r i o r p_{\mathrm{prior}} pprior:
x 0 ∼ q c o m p l e x ⟹ T ( x 0 ) ∼ p p r i o r \begin{equation} \mathbf{x}_0 \sim q_\mathrm{complex}⟹\mathcal{T}(\mathbf{x}_0) \sim p_\mathrm{prior} \end{equation} x0∼qcomplex⟹T(x0)∼pprior
具体来说,Diffusion模型提出可以用马尔科夫链(Markov Chain)来构造 T \mathcal{T} T,即定义一系列条件概率分布 q ( x t ∣ x t − 1 ) t ∈ { 1 , 2 , 3... T } q(\mathbf{x}_t \vert \mathbf{x}_{t-1})\quad t\in\{1,2,3...T\} q(xt∣xt−1)t∈{1,2,3...T},将 x 0 \mathbf{x_0} x0依次转换为 x 1 \mathbf{x_1} x1, x 2 \mathbf{x_2} x2 , . . . , x T ,...,\mathbf{x_T} ,...,xT,希望当 T → inf T \rightarrow \inf T→inf时, x T ∼ p prior \mathbf{x}_{T} \sim p_{\text {prior }} xT∼pprior 。
为了简洁和有效,此处的 p prior p_{\text {prior }} pprior 选择高斯分布,因此整个前向扩散过程可以被看作是,在 T T T步内,不断添加少量的高斯噪声到样本中。
q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q ( x 1 : T ∣ x 0 ) = ∏ t = 1 q ( x t ∣ x t − 1 ) q ( x T ) = p prior ( x T ) = N ( x T ; 0 , I ) where T → inf \begin{equation} \begin{array}{c} q\left(\mathbf{x}_{t} \vert \mathbf{x}_{t-1}\right)=\mathcal{N}\left(\mathbf{x}_{t} ; \sqrt{1-\beta_{t}} \mathbf{x}_{t-1}, \beta_{t} \mathrm{I}\right) \\ q\left(\mathbf{x}_{1: T} \vert \mathbf{x}_{0}\right)=\prod_{t=1} q\left(\mathbf{x}_{t} \vert \mathbf{x}_{t-1}\right)\\ q\left(\mathbf{x}_{T}\right)=p_{\text {prior }}\left(\mathbf{x}_{T}\right)=\mathcal{N}\left(\mathbf{x}_{T} ; \mathbf{0}, \mathrm{I}\right) \quad \text { where } T \rightarrow \inf \end{array} \end{equation} q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)q(x1:T∣x0)=∏t=1q(xt∣xt−1)q(xT)=pprior (xT)=N(xT;0,I) where T→inf
即已知 x t − 1 \mathbf{x_{t-1}} xt−1的时候, x t \mathbf{x_t} xt的概率分布为一个平均值为 1 − β t x t − 1 \sqrt{1-\beta_{t}} \mathbf{x}_{t-1} 1−βtxt−1,方差为 β t I \beta_tI βtI的高斯分布。随着 T T T的不断增大,最终数据分布变成了一个简单固定的高斯分布。
然后对公式2使用Diffusion Model(1):预备知识中提到的重参数化技巧(以下Diffusion Model(1):预备知识中的公式用1-xx代替)进行重参数化可以得到:
x t = 1 − β t x t − 1 + β t z t − 1 where z t − 1 ∈ N ( 0 , I ) \begin{equation} \mathbf{x}_{t}=\sqrt{1-\beta_{t}} \mathbf{x}_{t-1}+\sqrt{\beta_{t}} \mathbf{z}_{t-1} \quad \text { where } \mathbf{z}_{t-1} \in \mathcal{N}(0, \mathbf{I}) \end{equation} xt=1−βtxt−1+βtzt−1 where zt−1∈N(0,I)
这一过程即将高斯分布采样的过程变成了将 x t − 1 \mathbf{x_{t-1}} xt−1与标准高斯分布噪声 z \mathbf{z} z混合,扩散率系数 β t \beta_t βt控制融合 x t − 1 \mathbf{x_{t-1}} xt−1分布和标准高斯分布的比例。
设 α t = 1 − β t \alpha_t=1-\beta_t αt=1−βt以及 α ˉ t = ∏ i = 1 t α i \bar{\alpha}_{t}=\prod_{i=1}^{t} \alpha_{i} αˉt=∏i=1tαi,那么公式3就变成了:
x t = α t x t − 1 + 1 − α t z t − 1 ; where z t − 1 , z t − 2 , ⋯ ∼ N ( 0 , I ) = α t ( α t − 1 x t − 2 + 1 − α t − 1 z t − 2 ) + 1 − α t z t − 1 = α t α t − 1 x t − 2 + α t ( 1 − α t − 1 ) z t − 2 + 1 − α t z t − 1 = α t α t − 1 x t − 2 + 1 − α t − 1 α t z ˉ t − 2 ; where z ˉ t − 2 , z ˉ t − 3 , ⋯ ∼ N ( 0 , I ) = α ˉ t x 0 + 1 − α ˉ t z \begin{equation} \begin{array}{rlr} \mathbf{x}_{t} & =\sqrt{\alpha_{t}} {\color{blue}\mathbf{x}_{t-1}}+\sqrt{1-\alpha_{t}} \mathbf{z}_{t-1} & ; \text { where } \mathbf{z}_{t-1}, \mathbf{z}_{t-2}, \cdots \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\ & =\sqrt{\alpha_t}{\color{blue}(\sqrt{\alpha_{t-1}}x_{t-2} + \sqrt{1- \alpha_{t-1}} z_{t-2})} + \sqrt{1- \alpha_t} z_{t-1} & \\ & =\sqrt{\alpha_t\alpha_{t-1}}x_{t-2} + {\color{red}\sqrt{{\alpha_t}(1- \alpha_{t-1})} z_{t-2} + \sqrt{1- \alpha_t} z_{t-1}} \\ &= \sqrt{\alpha_t\alpha_{t-1}}x_{t-2} + {\color{red}\sqrt{1- \alpha_{t-1}\alpha_t} \bar{z}_{t-2}} & ; \text { where } \bar{\mathbf{z}}_{t-2}, \bar{\mathbf{z}}_{t-3}, \cdots \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\\ & =\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \mathbf{z} & \\ \end{array} \end{equation} xt=αtxt−1+1−αtzt−1=αt(αt−1xt−2+1−αt−1zt−2)+1−αtzt−1=αtαt−1xt−2+αt(1−αt−1)zt−2+1−αtzt−1=αtαt−1xt−2+1−αt−1αtzˉt−2=αˉtx0+1−αˉtz; where zt−1,zt−2,⋯∼N(0,I); where zˉt−2,zˉt−3,⋯∼N(0,I)
其中公式4从第一行到第二行是将 x t − 1 \mathbf{x_{t-1}} xt−1继续利用重参数化技巧展开,而从第三行到第四行利用了当两个高斯分布 N ( 0 , σ 1 2 I ) \mathcal{N}\left(\mathbf{0}, \sigma_{1}^{2} \mathbf{I}\right) N(0,σ12I)和 N ( 0 , σ 2 2 I ) \mathcal{N}\left(\mathbf{0}, \sigma_{2}^{2} \mathbf{I}\right) N(0,σ22I)相加时,新的分布为 N ( 0 , ( σ 1 2 + σ 2 2 ) I ) \mathcal{N}\left(\mathbf{0}, (\sigma_{1}^{2}+\sigma_{2}^{2}) \mathbf{I}\right) N(0,(σ12+σ22)I)的性质。
具体来说, α t ( 1 − α t − 1 ) z t − 2 \sqrt{\alpha_{t}\left(1-\alpha_{t-1}\right)} \mathbf{z}_{t-2} αt(1−αt−1)zt−2的方差为 α t ( 1 − α t − 1 ) \alpha_t(1-\alpha_{t-1}) αt(1−αt−1),而 1 − α t z t − 1 \sqrt{1-\alpha_{t}}\mathbf{z_{t-1}} 1−αtzt−1的方差为 1 − α t 1-\alpha_t 1−αt,因此新分布的方差为 1 − α t α t − 1 1-\alpha_t\alpha_{t-1} 1−αtαt−1。
将公式4写成条件概率的形式可以得到:
q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) \begin{equation} \color{red}q\left(\mathbf{x}_{t} \vert \mathbf{x}_{0}\right)=\mathcal{N}\left(\mathbf{x}_{t} ; \sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0},\left(1-\bar{\alpha}_{t}\right) \mathbf{I}\right) \end{equation} q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I)
此公式十分重要! 它意味着任意一个时刻 t \mathrm{t} t,我们都可以从 x 0 \mathbf{x}_0 x0直接计算得到 x t \mathbf{x}_t xt。
由于 β t ∈ ( 0 , 1 ) \beta_t \in (0,1) βt∈(0,1),所以 α t ∈ ( 0 , 1 ) \alpha_t \in (0,1) αt∈(0,1)。当 t → inf t \rightarrow \inf t→inf时, α ˉ t → 0 \bar{\alpha}_t \rightarrow 0 αˉt→0。可以得到, 1 − β t \sqrt{1-\beta_t} 1−βt和 β t \sqrt{\beta_t} βt作为系数时,保证了当 T → inf T \rightarrow \inf T→inf时, q ( x T = p p r i o r ( x T ) ) = N ( 0 , I ) q(\mathbf{x_T}=p_\mathrm{prior}(\mathbf{x_T}))=\mathcal{N}(0,\mathrm{I}) q(xT=pprior(xT))=N(0,I)。实际上,只要 T T T取一个很大的值,不需要无限次迭代,就可以近似于标准高斯分布。
β t ∈ R \beta_t \in \mathbb{R} βt∈R实际上是一个超参数,提前定义好的,同样 T T T也是一个超参数。如 T T T可设置为200, β t \beta_t βt可以设置为从0.0001到0.02的线性插值作为所有 β \beta β的取值。
以上就是原数据分布到简单先验噪声分布的转换过程 T \mathcal{T} T。值得注意的是,上述整个扩散过程没有出现一个可学习的参数,就可以将任意原始复杂的分布转换为简单先验分布(标准高斯分布)。
通过Diffusion模型的前向过程,复杂的分布 q c o m p l e x q_{\mathrm{complex}} qcomplex被转换为了一个标准高斯分布 p p r i o r p_{\mathrm{prior}} pprior。
Diffusion Model的逆向过程是从 p p r i o r p_\mathrm{prior} pprior中采样一个样本,将其转化为原始数据分布 q c o m p l e x q_\mathrm{complex} qcomplex中的一个样本。
因此类似于上一节的扩散过程,依次从 q ( x t − 1 ∣ x t ) , t ∈ { T , T − 1 , T − 2 , . . . , 0 } q(\mathbf{x}_{t-1}\vert \mathbf{x}_{t}), \, t\in \{T,T-1,T-2,...,0\} q(xt−1∣xt),t∈{T,T−1,T−2,...,0}中采样,Diffusion Model就可以实现从 x T ∼ N ( 0 , I ) \mathbf{x}_T\sim\mathcal{N}(0,\mathrm{I}) xT∼N(0,I)到数据分布 q c o m p l e x q_\mathrm{complex} qcomplex的转换。
不幸的是, q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1}\vert \mathbf{x}_{t}) q(xt−1∣xt)的分布是未知的。而 [Feller等人在1949年](https://projecteuclid.org/ebooks/berkeley-symposium-on-mathematical-statistics-and-probability/Proceedings of the [First] Berkeley Symposium on Mathematical Statistics and Probability/chapter/On the Theory of Stochastic Processes, with Particular Reference to Applications/bsmsp/1166219215)证明连续扩散过程的逆转具有与正向过程相同的分布形式。即当扩散率 β t \beta_t βt足够小,扩散次数足够多时,离散扩散过程接近于连续扩散过程, q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \vert \mathbf{x}_t) q(xt−1∣xt)的分布形式同 q ( x t ∣ x t − 1 ) q(\mathbf{x}_{t} \vert \mathbf{x}_{t-1}) q(xt∣xt−1)一致,同样是高斯分布。
但是我们依然很难直接写出 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1}\vert \mathbf{x}_{t}) q(xt−1∣xt)的分布参数。为此,我们需要学习一个模型 p θ ( x t − 1 ∣ x t ) p_\theta(\mathbf{x}_{t-1}\vert \mathbf{x}_t) pθ(xt−1∣xt)来近似 q ( x t − 1 ∣ x t ) \color{red}q(\mathbf{x}_{t-1}\vert \mathbf{x}_{t}) q(xt−1∣xt):
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p θ ( x 0 ) = ∫ p θ ( x 0 : T ) d x 1 : T p θ ( x 0 : T ) = p ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) \begin{equation} \begin{aligned} p_{\theta}\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right), \mathbf{\Sigma}_{\theta}\left(\mathbf{x}_{t}, t\right)\right)\\ p_{\theta}\left(\mathbf{x}_{0}\right)=\int p_{\theta}\left(\mathbf{x}_{0: T}\right) d \mathbf{x}_{1: T}\\ p_{\theta}\left(\mathbf{x}_{0: T}\right)=p\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}\right) \end{aligned} \end{equation} pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))pθ(x0)=∫pθ(x0:T)dx1:Tpθ(x0:T)=p(xT)t=1∏Tpθ(xt−1∣xt)
其中,这个高斯分布的均值 μ θ ( x t , t ) \mu_\theta(\mathbf{x}_t,t) μθ(xt,t)以及方差 Σ θ ( x t , t ) \mathbf{\Sigma}_{\theta}(\mathbf{x}_{t}, t) Σθ(xt,t)是需要学习的。
在有了 μ θ \mu_\theta μθ以及 Σ θ \mathbf{\Sigma}_\theta Σθ以后,就得到了 p θ ( x t − 1 ∣ x t ) p_\theta(\mathbf{x}_{t-1}\vert \mathbf{x}_t) pθ(xt−1∣xt)的分布,因此就可以完成整个逆转过程。首先从 N ( 0 , I ) \mathcal{N}(0, \mathrm{I}) N(0,I)中采样得到,然后在 x T \mathbf{x}_T xT以 μ θ ( x T , T ) \mu_\theta(\mathbf{x}_T,T) μθ(xT,T)为均值, Σ θ ( x T , T ) \mathbf{\Sigma}_{\theta}(\mathbf{x}_{T}, T) Σθ(xT,T)为方差的正态分布中采样得到 X T − 1 \mathbf{X}_{T-1} XT−1。然后重复这个过程,直到得到最终结果 x 0 \mathbf{x}_0 x0。
由于 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \vert \mathbf{x}_t) q(xt−1∣xt)未知,所以在逆转Diffusion过程中,用学习到的代替 p θ ( x t − 1 ∣ x t ) p_\theta(\mathbf{x}_{t-1}\vert \mathbf{x}_t) pθ(xt−1∣xt)它。
虽然 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \vert \mathbf{x}_t) q(xt−1∣xt)不容易得到,但是我们可以使用它的后验 q ( x t − 1 ∣ x t , x 0 ) \color{blue}q(\mathbf{x}_{t-1}\vert \mathbf{x}_t,\mathbf{x}_0) q(xt−1∣xt,x0)来替换它,由于逆向过程也是一个Markov Chain,因此 x 0 \mathbf{x}_0 x0是否存在并不会影响 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \vert \mathbf{x}_t) q(xt−1∣xt)。
逆向过程中高斯的后验概率定义:
q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ ( x t , x 0 ) , β ~ t I ) \begin{equation} q\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}, \mathbf{x}_{0}\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \tilde{\boldsymbol{\mu}}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right), \tilde{\beta}_{t} \mathbf{I}\right) \end{equation} q(xt−1∣xt,x0)=N(xt−1;μ~(xt,x0),β~tI)
使用后验概率,增加 x 0 \mathbf{x}_0 x0的原因有以下两个:
通过对后验概率 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}\vert \mathbf{x}_t,\mathbf{x}_0) q(xt−1∣xt,x0)使用公式1-8 的贝叶斯公式可以得到:
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) ∝ exp ( − 1 2 ( ( x t − α t x t − 1 ) 2 β t + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) ) = exp ( − 1 2 ( ( α t β t + 1 1 − α ˉ t − 1 ) x t − 1 2 − ( 2 α t β t x t + 2 α ˉ t − 1 1 − α ˉ t − 1 x 0 ) x t − 1 + C ( x t , x 0 ) ) ) \begin{equation} \begin{aligned} q\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}, \mathbf{x}_{0}\right) &=q\left(\mathbf{x}_{t} \vert \mathbf{x}_{t-1}, \mathbf{x}_{0}\right) \frac{q\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t} \vert \mathbf{x}_{0}\right)} \\ &=q\left(\mathbf{x}_{t} \vert \mathbf{x}_{t-1}\right) \frac{q\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t} \vert \mathbf{x}_{0}\right)} \\ & \propto \exp \left(-\frac{1}{2}\left(\frac{\left(\mathbf{x}_{t}-\sqrt{\alpha_{t}} \mathbf{x}_{t-1}\right)^{2}}{\beta_{t}}+\frac{\left(\mathbf{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_{0}\right)^{2}}{1-\bar{\alpha}_{t-1}}-\frac{\left(\mathbf{x}_{t}-\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}\right)^{2}}{1-\bar{\alpha}_{t}}\right)\right) \\ &=\exp \left(-\frac{1}{2}\left({\color{red}(\frac{\alpha_{t}}{\beta_{t}}+\frac{1}{1-\bar{\alpha}_{t-1}}) \mathbf{x}_{t-1}^{2}}-{\color{blue}(\frac{2 \sqrt{\alpha_{t}}}{\beta_{t}} \mathbf{x}_{t}+\frac{2 \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}} \mathbf{x}_{0}) \mathbf{x}_{t-1}}+{\color{blue}C\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)}\right)\right) \end{aligned} \end{equation} q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0)=q(xt∣xt−1)q(xt∣x0)q(xt−1∣x0)∝exp(−21(βt(xt−αtxt−1)2+1−αˉt−1(xt−1−αˉt−1x0)2−1−αˉt(xt−αˉtx0)2))=exp(−21((βtαt+1−αˉt−11)xt−12−(βt2αtxt+1−αˉt−12αˉt−1x0)xt−1+C(xt,x0)))
其中公式8第二行中结合公式2和公式5可以得到第三行:
q ( x t ∣ x t − 1 ) = N ( x t ; α t x t − 1 , β t I ) ∝ e x p ( − 1 2 ( x t − α t x t − 1 ) 2 β t ) q ( x t − 1 ∣ x 0 ) = N ( x t − 1 ; α ˉ t − 1 x 0 , ( 1 − α ˉ t − 1 ) I ) ∝ exp ( − 1 2 ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 ) q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) ∝ exp ( − 1 2 ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) \begin{aligned} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})&=\mathcal{N}(\mathbf{x}_t; \sqrt{\alpha}_t\mathbf{x}_{t-1},\beta_t\mathbf{I}) \propto \mathrm{exp}(-\frac{1}{2}\frac{(\mathbf{x}_t-\sqrt{\alpha_t}\mathbf{x}_{t-1})^2}{\beta_t})\\ q(\mathbf{x}_{t-1} \vert \mathbf{x}_0) &= \mathcal{N}(\mathbf{x}_{t-1}; \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0, (1 - \bar{\alpha}_{t-1})\mathbf{I}) \propto \exp(-\frac{1}{2}\frac{(\mathbf{x}_{t-1} - \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0)^2}{1 - \bar{\alpha}_{t-1}})\\ q(\mathbf{x}_{t} \vert \mathbf{x}_0) &= \mathcal{N}(\mathbf{x}_{t}; \sqrt{\bar{\alpha}_{t}}\mathbf{x}_0, (1 - \bar{\alpha}_{t})\mathbf{I}) \propto \exp(-\frac{1}{2}\frac{(\mathbf{x}_{t} - \sqrt{\bar{\alpha}_{t}}\mathbf{x}_0)^2}{1 - \bar{\alpha}_{t}}) \end{aligned} q(xt∣xt−1)q(xt−1∣x0)q(xt∣x0)=N(xt;αtxt−1,βtI)∝exp(−21βt(xt−αtxt−1)2)=N(xt−1;αˉt−1x0,(1−αˉt−1)I)∝exp(−211−αˉt−1(xt−1−αˉt−1x0)2)=N(xt;αˉtx0,(1−αˉt)I)∝exp(−211−αˉt(xt−αˉtx0)2)
而其中的第四行可以利用 a x 2 + b x + C = a ( x + b 2 a ) 2 ax^2 + bx + C = a(x + \frac{b}{2a})^2 ax2+bx+C=a(x+2ab)2公式,将其凑成高斯分布概率密度的形式。
因此,我们可以得到 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}\vert \mathbf{x}_t,\mathbf{x}_0) q(xt−1∣xt,x0)的高斯概率密度表示为:
q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ ( x t , x 0 ) , β t ~ I ) ≈ exp ( − ( x − μ ~ ( x t , x 0 ) ) 2 2 β ~ t ) \begin{equation} q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)= \mathcal{N}(\mathbf{x}_{t-1}; {\color{blue}\tilde{\mu}(\mathbf{x}_t, \mathbf{x}_0)},{\color{red} \tilde{\beta_t}\mathbf{I})} \approx \exp \left( -\frac{(\mathbf{x} - \tilde{\mu}(\mathbf{x}_t, \mathbf{x}_0))^2}{2\tilde{\beta}_t} \right) \end{equation} q(xt−1∣xt,x0)=N(xt−1;μ~(xt,x0),βt~I)≈exp(−2β~t(x−μ~(xt,x0))2)
然后,使用公式4,将其中的 x 0 = 1 α t ˉ ( x t − 1 − α ˉ t z t ) \mathbf{x}_0 = \frac{1}{\sqrt{\bar{\alpha_t}}} (\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t} z_t) x0=αtˉ1(xt−1−αˉtzt)替换为 x t \mathbf{x}_t xt,因此公式9中的均值 μ ~ t ( x t , x 0 ) \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0) μ~t(xt,x0),对其进行推导可以得到:
μ ~ t ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t ⋅ 1 α ˉ t ( x t − 1 − α ˉ t z t ) = α t ⋅ α t ( 1 − α ˉ t − 1 ) α t ⋅ ( 1 − α ˉ t ) x t + α ˉ t − 1 β t 1 − α ˉ t ⋅ 1 α ˉ t ( x t − 1 − α ˉ t z t ) = α t − α ˉ t α t ( 1 − α ˉ t ) x t + β t ( 1 − α ˉ t ) α t ( x t − 1 − α ˉ t z t ) = 1 − α ˉ t α t ( 1 − α ˉ t ) x t − β t ( 1 − α ˉ t ) α t ( 1 − α ˉ t z t ) = 1 α t x t − β t ( 1 − α ˉ t ) α t z t = 1 α t ( x t − β t ( 1 − α ˉ t ) z t ) \begin{equation} \begin{aligned} \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0) &= \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}\mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_{t}} {\color{red}\mathbf{x}_0} \\ &=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}\mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_{t}} \cdot {\color{red}\frac{1}{\sqrt{\bar{\alpha}_t}} (\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t} z_t)} \\ & = \frac{{\color{green}\sqrt{\alpha_t}}\cdot\sqrt{\alpha_t}(1-{\color{green}\bar{\alpha}_{t-1}})}{{ {\color{green}\sqrt{\alpha_t}}}\cdot( 1-\bar{\alpha}_{t})}\mathbf{x}_t + \frac{{\color{blue}\sqrt{\bar{\alpha}_{t-1}}}\beta_t}{1-\bar{\alpha}_{t}} \cdot \frac{1}{\color{blue}\sqrt{\bar{\alpha}_t}} (x_t - \sqrt{1 - \bar{\alpha}_t} z_t) \\ & = \frac{{\color{purple}\alpha_t}-\color{green}\bar{\alpha}_{t}}{{\sqrt{\alpha_t}}(1-\bar{\alpha}_{t})}\mathbf{x}_t + \frac{\color{purple}\beta_t}{(1-\bar{\alpha}_{t})\color{blue}\sqrt{{\alpha_t}}} (x_t - \sqrt{1 - \bar{\alpha}_t} z_t) \\ & = \frac{{\color{purple}1}-\bar{\alpha}_{t}}{{\sqrt{\alpha_t}}(1-\bar{\alpha}_{t})}\mathbf{x}_t- \frac{\beta_t}{(1-\bar{\alpha}_{t})\sqrt{{\alpha_t}}} (\sqrt{1 - \bar{\alpha}_t} z_t) \\ & = \frac{1}{{\sqrt{\alpha_t}}}\mathbf{x}_t - \frac{\beta_t}{\sqrt{(1-\bar{\alpha}_{t})}\sqrt{{\alpha_t}}} z_t \\ & = \color{brown}\frac{1}{{\sqrt{\alpha_t}}}\big(\mathbf{x}_t - \frac{\beta_t}{\sqrt{(1-\bar{\alpha}_{t})}} z_t \big) \\ \end{aligned} \end{equation} μ~t(xt,x0)=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βtx0=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βt⋅αˉt1(xt−1−αˉtzt)=αt⋅(1−αˉt)αt⋅αt(1−αˉt−1)xt+1−αˉtαˉt−1βt⋅αˉt1(xt−1−αˉtzt)=αt(1−αˉt)αt−αˉtxt+(1−αˉt)αtβt(xt−1−αˉtzt)=αt(1−αˉt)1−αˉtxt−(1−αˉt)αtβt(1−αˉtzt)=αt1xt−(1−αˉt)αtβtzt=αt1(xt−(1−αˉt)βtzt)
上述公式的第四行到第五行的变换用到了 β t = 1 − α t \beta_t=1-\alpha_t βt=1−αt
到此,我们在逆向过程中的目标就变成了拉近以下两个高斯分布的距离,这可以通过计算两个分布的KL散度实现,其中 q ( x t − 1 ∣ x t , x 0 ) q\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}, \mathbf{x}_{0}\right) q(xt−1∣xt,x0)的均值和方差都是已知的:
q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ ( x t , x 0 ) , β ~ t I ) ⟷ p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) \begin{equation} q\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}, \mathbf{x}_{0}\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \tilde{\boldsymbol{\mu}}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right), \tilde{\beta}_{t} \mathbf{I}\right) \longleftrightarrow p_{\theta}\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right), \boldsymbol{\Sigma}_{\theta}\left(\mathbf{x}_{t}, t\right)\right) \end{equation} q(xt−1∣xt,x0)=N(xt−1;μ~(xt,x0),β~tI)⟷pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))
References: