参考:
[1] Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models[J]. Advances in neural information processing systems, 2020, 33: 6840-6851.
[2] 扩散模型/Diffusion Model原理讲解_哔哩哔哩_bilibili
[3] 扩散模型公式推导_扩散模型数学推导-CSDN博客
[4] 扩散模型的一些公式证明_扩散模型公式-CSDN博客
[5] 超详细!!扩散模型基本原理讲解,一文搞懂扩散模型_扩散模型原理-CSDN博客
推荐在简单了解扩散模型原理后再来看本篇文章,加深对理论的理解,本篇只叙述有关扩散模型公式理论的推导~
q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) . (1) q(x_t|x_{t-1})=\mathcal{N}(x_t;\sqrt{1-\beta_t}x_{t-1},\beta_tI).\quad\text{(1)} q(xt∣xt−1)=N(xt;1−βtxt−1,βtI).(1)
其中,x0 是mel-spec,xT 是纯高斯分布噪声,β 是个超参数,代表添加噪声的幅度,或者也可以写成这样,这实际上就用到了重参数技巧来写的公式:
x t = 1 − β t x t − 1 + β t ϵ t (2) \mathbf{x}_t=\sqrt{1-\beta_t}\mathbf{x}_{t-1}+\sqrt{\beta_t}\boldsymbol{\epsilon}_t\quad\text{(2)} xt=1−βtxt−1+βtϵt(2)
其中,ε 为从标准高斯分布中采样的噪声。在原扩散模型中,β 会变得越来越大(通过线性调度器或者余弦调度器)引用【深度学习模型】扩散模型(Diffusion Model)基本原理及代码讲解-CSDN博客,即:
所谓的加噪声,就是基于稍微干净的图片计算一个(多维)高斯分布(每个像素点都有一个高斯分布,且均值就是这个像素点的值,方差是预先定义的 ),然后从这个多维分布中抽样一个数据出来,这个数据就是加噪之后的结果。显然,如果方差非常非常小,那么每个抽样得到的像素点就和原本的像素点的值非常接近,也就是加了一个非常非常小的噪声。如果方差比较大,那么抽样结果就会和原本的结果差距较大。
去噪声也是同理,我们基于稍微噪声的图片 计算一个条件分布,我们希望从这个分布中抽样得到的是相比于 更加接近真实图片的稍微干净的图片。我们假设这样的条件分布是存在的,并且也是个高斯分布,那么我们只需要知道均值和方差就可以了。问题是这个均值和方差是无法直接计算的,所以用神经网络去学习近似这样一个高斯分布。
当然,我们还可以将迭代式进行推导,将 xt 写成闭式解(即无需逐步迭代进行计算)。为了方便计算,我们有如下定义:
α t : = 1 − β t , α ˉ t : = ∏ s = 1 t α s (3) \alpha_{t}:=1-\beta_{t},\quad\bar{\alpha}_{t}:=\prod_{s=1}^{t}\alpha_{s}\quad\text{(3)} αt:=1−βt,αˉt:=s=1∏tαs(3)
结合迭代、重参数技巧,我们不难推导出:
q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) x t = α ˉ t x 0 + 1 − α ˉ t ϵ (4) \begin{aligned} q(\mathbf{x}_t|\mathbf{x}_0) &= \mathcal{N}(\mathbf{x}_t;\sqrt{\bar{\alpha}_t}\mathbf{x}_0,(1-\bar{\alpha}_t)\mathbf{I}) \\ x_t &= \sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon \end{aligned} \quad\text{(4)} q(xt∣x0)xt=N(xt;αˉtx0,(1−αˉt)I)=αˉtx0+1−αˉtϵ(4)
其中,ε 服从标准高斯分布。
去噪的话我们使用 p 来表示反向重建的过程,使用 θ 表示神经网络参数,实际上,重建的过程就是(图引用自[2]):
我们定义 p 为参数化高斯分布,去近似逆向扩散的过程。则逆向过程可以写为:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) (5) p_\theta(x_{t-1}|x_t)=\mathcal{N}(x_{t-1};\mu_\theta(x_t,t),\Sigma_\theta(x_t,t))\quad\text{(5)} pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))(5)
我们需要去通过神经网络去得到平均值和方差,并从中采样。当然,采样这一过程之后还是会通过重参数去改写。那这一逆向过程我们应该如何计算呢?如何计算从 xt 到 xt-1?先来看看已知条件:
x t = α t x t − 1 + 1 − α t ϵ t (6) x t = α ˉ t x 0 + 1 − α ˉ t ϵ (7) x_t=\sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}\epsilon_t\quad\text{(6)} \\ x_t = \sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon\quad\text{(7)} \\ xt=αtxt−1+1−αtϵt(6)xt=αˉtx0+1−αˉtϵ(7)
我们还有贝叶斯公式,并添加 x0 条件后,我们可以有:
p ( x t − 1 ∣ x t , x 0 ) = p ( x t ∣ x t − 1 , x 0 ) p ( x t − 1 ∣ x 0 ) p ( x t ∣ x 0 ) (8) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,x_0)=\frac{p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1},x_0)p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)}{p(\boldsymbol{x}_t|\boldsymbol{x}_0)}\quad\text{(8)} p(xt−1∣xt,x0)=p(xt∣x0)p(xt∣xt−1,x0)p(xt−1∣x0)(8)
诶!RHS是不是就出现了前向过程?我们之前已经讲过了前向过程的高斯分布,因此我们不难得到:
p ( x t ∣ x t − 1 , x 0 ) = α t x t − 1 + 1 − α t ϵ ∼ N ( α t x t − 1 , 1 − α t ) (9) p ( x t − 1 ∣ x 0 ) = α ‾ t − 1 x 0 + 1 − α ‾ t − 1 ϵ ∼ N ( α ‾ t − 1 x 0 , 1 − α ‾ t − 1 ) (10) p(x_t|x_{t-1},x_0)=\sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}\epsilon\sim\mathcal{N}(\sqrt{\alpha_t}x_{t-1},1-\alpha_t)\quad\text{(9)} \\ p(x_{t-1}|x_0)=\sqrt{\overline{\alpha}_{t-1}}x_0+\sqrt{1-\overline{\alpha}_{t-1}}\epsilon\sim\mathcal{N}(\sqrt{\overline{\alpha}_{t-1}}x_0,1-\overline{\alpha}_{t-1})\quad\text{(10)} p(xt∣xt−1,x0)=αtxt−1+1−αtϵ∼N(αtxt−1,1−αt)(9)p(xt−1∣x0)=αt−1x0+1−αt−1ϵ∼N(αt−1x0,1−αt−1)(10)
根据高斯分布的表达式:
N ( x ; μ , σ 2 ) = 1 2 π σ 2 e − ( x − μ ) 2 2 σ 2 \mathcal{N}(x;\mu,\sigma^2)=\frac1{\sqrt{2\pi\sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}} N(x;μ,σ2)=2πσ21e−2σ2(x−μ)2
我们将式(8)改写成高斯分布,带入式(9)(10)后,最后经过一系列化简后我们就可以得到:
p ( x t − 1 ∣ x t , x 0 ) ∼ N ( a t ( 1 − a ˉ t − 1 ) 1 − a ˉ t x t + a ˉ t − 1 ( 1 − a t ) 1 − a ˉ t x 0 , ( 1 − a t 1 − a ˉ t − 1 1 − a ˉ t ) 2 ) (11) p(x_{t-1}|x_t,x_0)\sim N\left(\frac{\sqrt{a_t}(1-\bar{a}_{t-1})}{1-\bar{a}_t}x_t+\frac{\sqrt{\bar{a}_{t-1}}(1-a_t)}{1-\bar{a}_t}x_0,\left(\frac{\sqrt{1-a_t}\sqrt{1-\bar{a}_{t-1}}}{\sqrt{1-\bar{a}_t}}\right)^2\right)\quad\text{(11)} p(xt−1∣xt,x0)∼N(1−aˉtat(1−aˉt−1)xt+1−aˉtaˉt−1(1−at)x0,(1−aˉt1−at1−aˉt−1)2)(11)
不过,这里面有我们的目标 x0,因此我们需要将 x0 给置换掉,而由公式(4)我们可以知道:
x t = α ‾ t x 0 + 1 − α ‾ t ϵ ⟶ x 0 = x t − 1 − α ‾ t ϵ α ‾ t (12) x_t=\sqrt{\overline{\alpha}_t}x_0+\sqrt{1-\overline{\alpha}_t}\epsilon\longrightarrow x_0=\frac{x_t-\sqrt{1-\overline{\alpha}_t}\epsilon}{\sqrt{\overline{\alpha}_t}}\quad\text{(12)} xt=αtx0+1−αtϵ⟶x0=αtxt−1−αtϵ(12)
带入式(11)后,我们就最终得到了去噪过程的推导式:
p ( x t − 1 ∣ x t , ϵ ) ∼ N ( 1 α t ( x t − 1 − α t 1 − α ‾ t ϵ ) , ( 1 − a t 1 − a ˉ t − 1 1 − a ˉ t ) 2 ) (13) p(x_{t-1}|x_t,\epsilon)\sim N\left(\frac1{\sqrt{\alpha_t}}(x_t-\frac{1-\alpha_t}{\sqrt{1-\overline{\alpha}_t}}\epsilon),\left(\frac{\sqrt{1-a_t}\sqrt{1-\bar{a}_{t-1}}}{\sqrt{1-\bar{a}_t}}\right)^2\right)\quad\text{(13)} p(xt−1∣xt,ϵ)∼N(αt1(xt−1−αt1−αtϵ),(1−aˉt1−at1−aˉt−1)2)(13)
我们现在就有了去噪过程的推导式,只需要 xt 和 ε 就可以推导出 x_t-1。而 xt 我们是知道的,ε 不知道,那么问题就从原来的输入为 xt,输出 x_t-1,到了如今的输入 xt,预测 ε,并带入推导式中。问题的关键就是去预测噪声 ε。
好的,现在回到扩散模型本身来。我们的目的就是为了去算好 x_t-1 的概率分布,也就是能预测好 ε,计算好分布的平均值和方差。我们应该如何让网络去拟合 ε?我们先来看看我们的目标函数:我们的目的是为了通过神经网络去得到 x0,我们使用 θ 表示神经网络参数,则我们的目标是:
p θ ( x 0 ) : = ∫ p θ ( x 0 : T ) d x 1 : T , p θ ( x 0 : T ) : = p ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) (14) p_{\theta}(\mathbf{x}_{0}):=\int p_{\theta}(\mathbf{x}_{0:T}) d\mathbf{x}_{1:T}, \quad p_\theta(\mathbf{x}_{0:T}):=p(\mathbf{x}_T)\prod_{t=1}^Tp_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)\quad\text{(14)} pθ(x0):=∫pθ(x0:T)dx1:T,pθ(x0:T):=p(xT)t=1∏Tpθ(xt−1∣xt)(14)
直接积分是肯定不行的,但我们可以变相去求解它的变分下界。在求解变分下界之前,我们先将前向加噪的过程 q 拿来一用。由于咱们是从原图片开始,通过 Markov 链的形式来一点一点进行加噪,我们当然可以将这一过程视为近似后验分布,因此我们就可以定义这样一个基于所有中间加噪图片潜在变量的近似后验分布:
q ( x 1 : T ∣ x 0 ) : = ∏ t = 1 T q ( x t ∣ x t − 1 ) , q ( x t ∣ x t − 1 ) : = N ( x t ; 1 − β t x t − 1 , β t I ) (15) q(\mathbf{x}_{1:T}|\mathbf{x}_0):=\prod_{t=1}^Tq(\mathbf{x}_t|\mathbf{x}_{t-1}),\quad q(\mathbf{x}_t|\mathbf{x}_{t-1}):=\mathcal{N}(\mathbf{x}_t;\sqrt{1-\beta_t}\mathbf{x}_{t-1},\beta_t\mathbf{I})\quad\text{(15)} q(x1:T∣x0):=t=1∏Tq(xt∣xt−1),q(xt∣xt−1):=N(xt;1−βtxt−1,βtI)(15)
有了前向加噪过程 q,也就是近似后验分布,你就既可以通过 KL 散度联系起两个后验分布p和q从而推导出变分下界,也可以通过式(14)(15),加上Jensen不等式来推导出变分下界,结果目标函数(损失函数)就如下:
− log p θ ( x 0 ) ≤ E x 1 : T ∼ q ( x 1 : T ∣ x 0 ) [ − log p θ ( x 0 : T ) q ( x 1 : T ∣ x 0 ) ] = E q [ − log p ( x T ) − ∑ t ≥ 1 log p θ ( x t − 1 ∣ x t ) q ( x t ∣ x t − 1 ) ] = : L (16) -\log p_\theta(\mathbf{x}_0)\leq\mathbb{E}_{\mathbf{x_{1:T}}\sim{q(\mathbf{x_{1:T}}|\mathbf{x}_0)}}\left[-\log\frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_0)}\right]=\mathbb{E}_q\left[-\log p(\mathbf{x}_T)-\sum_{t\geq1}\log\frac{p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)}{q(\mathbf{x}_t|\mathbf{x}_{t-1})}\right]=:L\quad\text{(16)} −logpθ(x0)≤Ex1:T∼q(x1:T∣x0)[−logq(x1:T∣x0)pθ(x0:T)]=Eq[−logp(xT)−t≥1∑logq(xt∣xt−1)pθ(xt−1∣xt)]=:L(16)
而这一个损失函数也可以进一步改写,改写的推导过程可以参考扩散模型的一些公式证明_扩散模型公式-CSDN博客,不过我没有自己验证过,结果是和原论文(DDPM)是一样的。改写为:
L = D K L ( q ( x T ∣ x 0 ) ∣ ∣ p ( x T ) ) + ∑ t > 1 D K L ( q ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p ( x t − 1 ∣ x t ) ) − l o g p ( x 0 ∣ x 1 ) (17) L=D_{KL}\left(q(x_T|x_0)||p(x_T)\right)+\sum_{t>1}D_{KL}\left(q(x_{t-1}|x_t,x_0)||p(x_{t-1}|x_t)\right)-logp(x_0|x_1)\quad\text{(17)} L=DKL(q(xT∣x0)∣∣p(xT))+t>1∑DKL(q(xt−1∣xt,x0)∣∣p(xt−1∣xt))−logp(x0∣x1)(17)
损失函数肉眼可见有三项,我们规定从左到右分别为 L_T,L_t-1 和 L_0:
那么这样一下子讨论下来,实际上损失函数只有 L_t-1 这一项最有效。我们来看看怎么计算 L_t-1 这一项。首先这两个分布都是高斯分布,因此我们可以根据高斯分布表达式去简写 KL 散度的公式。有如下两个高斯分布之间的 KL 散度计算公式:
D K L ( N ( μ 1 , σ 1 2 ) ∣ ∣ N ( μ 2 , σ 2 2 ) ) = l o g σ 2 σ 1 + σ 1 2 + ( μ 1 − μ 2 ) 2 2 σ 2 2 − 1 2 (18) D_{KL}\left(\mathcal{N}(\mu_1,\sigma_1^2)\left|\right|\mathcal{N}(\mu_2,\sigma_2^2)\right)=log\frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac12\quad\text{(18)} DKL(N(μ1,σ12)∣∣N(μ2,σ22))=logσ1σ2+2σ22σ12+(μ1−μ2)2−21(18)
我们将这两个分布用高斯分布的形式表示出来如下:
q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ t ( x t , x 0 ) , β ~ t I ) (19) p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) (20) q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)=\mathcal{N}(\mathbf{x}_{t-1};\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t,\mathbf{x}_0),\tilde{\beta}_t\mathbf{I})\quad\text{(19)} \\ p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)=\mathcal{N}(\mathbf{x}_{t-1};\boldsymbol{\mu}_\theta(\mathbf{x}_t,t),\boldsymbol{\Sigma}_\theta(\mathbf{x}_t,t))\quad\text{(20)} q(xt−1∣xt,x0)=N(xt−1;μ~t(xt,x0),β~tI)(19)pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))(20)
在原来的模型中,作者将模型预测的方差 Σ 设置成了一个无法训练的仅与时间步相关的变量,即 Σ = σ_t^2 I,
值得一提的是,作者声明从实验结果来看,无论将 σ_t^2 设置为 β_t 还是和式(13)中一样的方差,结果上来看都差不多
而由式(13)我们知道,q 的方差同样也和模型无关(也是只和时间步相关),因此在式(18)中只有平均值参与的计算才有意义,我们将式(17)的 L_t-1 化简后如下:
L t − 1 = 1 2 σ t 2 ∥ μ ~ t ( x t , x 0 ) − μ θ ( x t , t ) ∥ 2 + C (21) L_{t-1}=\frac1{2\sigma_t^2}\|\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t,\mathbf{x}_0)-\boldsymbol{\mu}_\theta(\mathbf{x}_t,t)\|^2+C\quad\text{(21)} Lt−1=2σt21∥μ~t(xt,x0)−μθ(xt,t)∥2+C(21)
这里其实可以看出,模型其实就是在预测前向过程后验分布的均值。根据式(13),并且由于 μ tilde 是个前向过程后验均值,受到 x0 的影响,我们结合式(4),有:
μ ~ t ( x t , x 0 ) = 1 α t ( x t ( x 0 , ϵ ) − β t 1 − α ˉ t ϵ ) = 1 α t ( ( α ˉ t x 0 + 1 − α ˉ t ϵ ) − β t 1 − α ˉ t ϵ ) (22) \begin{aligned} \tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t,\mathbf{x}_0) &= \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t(\mathbf{x}_0,\boldsymbol{\epsilon})-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol{\epsilon}\right) \\ &=\frac{1}{\sqrt{\alpha_t}}\left((\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon})-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol{\epsilon}\right) \end{aligned}\quad\text{(22)} μ~t(xt,x0)=αt1(xt(x0,ϵ)−1−αˉtβtϵ)=αt1((αˉtx0+1−αˉtϵ)−1−αˉtβtϵ)(22)
此外,利用式(13),模型预测的均值 μ_θ 也可以写出来,注意这里的 ε 不再是已知的,而是需要模型去预测,我们在式(13)的后续讲解中有交代过。
μ θ ( x t , t ) = 1 α t ( ( α ˉ t x 0 + 1 − α ˉ t ϵ ) − β t 1 − α ˉ t ϵ θ ( x t , t ) ) (23) \boldsymbol{\mu}_\theta(\mathbf{x}_t,t)=\frac{1}{\sqrt{\alpha_t}}\left((\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon})-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\boldsymbol{\epsilon}_\theta(\mathbf{x}_t,t)\right)\quad\text{(23)} μθ(xt,t)=αt1((αˉtx0+1−αˉtϵ)−1−αˉtβtϵθ(xt,t))(23)
将式(22)和式(23)带入式(21)中,并结合式(4)替换掉 xt,规整一下写法,我们就得到了最终的化简结果:
L ( θ ) = E x 0 , ϵ , t [ β t 2 2 σ t 2 α t ( 1 − α ˉ t ) ∥ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∥ 2 ] (24) L(\theta)=\mathbb{E}_{\mathbf{x}_0,\boldsymbol{\epsilon},t}\left[\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)}\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta(\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon},t)\right\|^2\right]\quad\text{(24)} L(θ)=Ex0,ϵ,t[2σt2αt(1−αˉt)βt2 ϵ−ϵθ(αˉtx0+1−αˉtϵ,t) 2](24)
最后,作者又将这个损失函数进一步简化,据说是可以提高采样的质量并且更易于实现:
L s i m p l e ( θ ) : = E t , x 0 , ϵ [ ∥ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∥ 2 ] (25) L_{\mathrm{simple}}(\theta):=\mathbb{E}_{t,\mathbf{x}_0,\boldsymbol{\epsilon}}\Big[\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta(\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon},t)\right\|^2\Big]\quad\text{(25)} Lsimple(θ):=Et,x0,ϵ[ ϵ−ϵθ(αˉtx0+1−αˉtϵ,t) 2](25)
式(25)就是最终的损失函数了。