奠基性的工作:
条件概率的一般形式
P ( B , C ∣ A ) = P ( B ∣ A ) P ( C ∣ A , B ) P(B,C|A)=P(B|A)P(C|A,B) P(B,C∣A)=P(B∣A)P(C∣A,B)
基于马尔可夫假设的条件概率
假设马尔可夫链关系 A → B → C A\to B\to C A→B→C,有
P ( A , B , C ) = P ( C ∣ B ) P ( B ∣ A ) P ( A ) P(A,B,C)=P(C|B)P(B|A)P(A) P(A,B,C)=P(C∣B)P(B∣A)P(A)
高斯分布的KL散度
对于两个单一变量的高斯分布p和q而言,他们的KL散度满足
K L ( N ( μ 1 , σ 1 2 ) , N ( μ 2 , σ 2 2 ) ) = log σ 2 σ 1 − 1 2 + σ 1 2 + ( μ 1 − μ 2 ) 2 2 σ 2 2 KL(\mathcal{N}(\mu_1,\sigma_1^2),\mathcal{N}(\mu_2,\sigma_2^2))=\log\frac{\sigma_2}{\sigma_1}-\frac{1}{2}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} KL(N(μ1,σ12),N(μ2,σ22))=logσ1σ2−21+2σ22σ12+(μ1−μ2)2
推导详见CSDN博客
参数重整化
若希望从高斯分布 N ( μ , σ 2 ) \mathcal{N}(\mu,\sigma^2) N(μ,σ2)中采样,可以先从标准分布 N ( 0 , 1 ) \mathcal{N}(0,1) N(0,1)得到 z z z,得到 σ ⋅ z + μ \sigma\cdot z+\mu σ⋅z+μ。
这样就可以将 σ \sigma σ和 μ \mu μ也作为仿射网络的一部分,而不是不可导的环境参数。
这个技巧在VAE和Diffusion model中大量被使用。
x → z , q ϕ ( z ∣ x ) z → x , p θ ( x ∣ z ) x\to z,\quad q_{\phi}(z|x)\\ z\to x,\quad p_{\theta}(x|z) x→z,qϕ(z∣x)z→x,pθ(x∣z)
此时 x x x的边缘概率分布可以改写为关于z的期望式
p ( x ) = ∫ z p θ ( x ∣ z ) p ( z ) d z = ∫ z q ϕ ( z ∣ x ) p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) d z = E z ∼ q ϕ ( z ∣ x ) p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) \begin{aligned} p(x)&=\int_zp_\theta(x|z)p(z)\text{d}z\\ &=\int_zq_\phi(z|x)\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)}\text{d}z\\ &=\mathbb{E}_{z\sim q_\phi(z|x)}\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)} \end{aligned} p(x)=∫zpθ(x∣z)p(z)dz=∫zqϕ(z∣x)qϕ(z∣x)pθ(x∣z)p(z)dz=Ez∼qϕ(z∣x)qϕ(z∣x)pθ(x∣z)p(z)
此时的Evidence存在一个lower bound(ELBO)
log p ( x ) = log E z ∼ q ϕ ( z ∣ x ) p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) ≥ E z ∼ q ϕ ( z ∣ x ) log [ p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) ] \log p(x)=\log\mathbb{E}_{z\sim q_\phi(z|x)}\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)} \ge\mathbb{E}_{z\sim q_\phi(z|x)}\log\left[\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)}\right] logp(x)=logEz∼qϕ(z∣x)qϕ(z∣x)pθ(x∣z)p(z)≥Ez∼qϕ(z∣x)log[qϕ(z∣x)pθ(x∣z)p(z)]
在训练中,我们需要最大化对数似然,即Evidence,可以通过最小化lower bound实现,而这个lower bound可以分为两部分:
所以,单层VAE的损失函数是可优化的。
基于同样的原理,
p ( x ) = ∫ z 1 ∫ z 2 p θ ( x , z 1 , z 2 ) d z 1 d z 2 = ∫ z 1 ∫ z 2 q ϕ ( z 1 , z 2 ∣ x ) p θ ( x , z 1 , z 2 ) q ϕ ( z 1 , z 2 ∣ x ) d z 1 d z 2 = E z 1 , z 2 ∼ q ϕ ( z 1 , z 2 ∣ x ) p θ ( x , z 1 , z 2 ) q ϕ ( z 1 , z 2 ∣ x ) \begin{aligned} p(x)&=\int_{z_1}\int_{z_2}p_\theta(x,z_1,z_2)\text{d}z_1\text{d}z_2\\ &=\int_{z_1}\int_{z_2}q_\phi(z_1,z_2|x)\frac{p_\theta(x,z_1,z_2)}{q_\phi(z_1,z_2|x)}\text{d}z_1\text{d}z_2\\ &=\mathbb{E}_{z1,z_2\sim q_\phi(z_1,z_2|x)}\frac{p_\theta(x,z_1,z_2)}{q_\phi(z_1,z_2|x)} \end{aligned} p(x)=∫z1∫z2pθ(x,z1,z2)dz1dz2=∫z1∫z2qϕ(z1,z2∣x)qϕ(z1,z2∣x)pθ(x,z1,z2)dz1dz2=Ez1,z2∼qϕ(z1,z2∣x)qϕ(z1,z2∣x)pθ(x,z1,z2)
得到
log p ( x ) ≥ E z 1 , z 2 ∼ q ϕ ( z 1 , z 2 ∣ x ) log p θ ( x , z 1 , z 2 ) q ϕ ( z 1 , z 2 ∣ x ) \log p(x)\ge \mathbb{E}_{z1,z_2\sim q_\phi(z_1,z_2|x)}\log \frac{p_\theta(x,z_1,z_2)}{q_\phi(z_1,z_2|x)} logp(x)≥Ez1,z2∼qϕ(z1,z2∣x)logqϕ(z1,z2∣x)pθ(x,z1,z2)
如果上述过程满足马尔科夫假设,即
p θ ( x , z 1 , z 2 ) = p ( x ∣ z 1 ) p ( z 1 ∣ z 2 ) p ( z 2 ) q ( z 1 , z 2 ∣ x ) = q ( z 1 ∣ x ) q ( z 2 ∣ z 1 ) p_\theta(x,z_1,z_2)=p(x|z_1)p(z_1|z_2)p(z_2)\\ q(z_1,z_2|x)=q(z_1|x)q(z_2|z_1) pθ(x,z1,z2)=p(x∣z1)p(z1∣z2)p(z2)q(z1,z2∣x)=q(z1∣x)q(z2∣z1)
(6)式能够被进一步简化为
L ( θ , ϕ ) = E q ( z 1 , z 2 ∣ x ) [ log p ( x ∣ z 1 ) − log q ( z 1 ∣ x ) + log p ( z 1 ∣ z 2 ) − log q ( z 2 ∣ z 1 ) + log p ( z 2 ) ] \mathcal{L}(\theta,\phi)=\mathbb{E}_{q(z_1,z_2|x)} \left[ \log p(x|z_1)-\log q(z_1|x)+\log p(z_1|z_2) -\log q(z_2|z_1) +\log p(z_2) \right] L(θ,ϕ)=Eq(z1,z2∣x)[logp(x∣z1)−logq(z1∣x)+logp(z1∣z2)−logq(z2∣z1)+logp(z2)]
从右往左,从目标分布到噪声分布称为扩散过程,而我们希望学习到从左往右的逆扩散过程。上图中的第一行从左往右是扩散过程,第二行从右往左是逆扩散过程,而第三行是前两者的差值,称为偏移量。
给定初始数据分布 x 0 ∼ q ( x ) \bold{x_0}\sim q(\bold{x}) x0∼q(x),不断向分布中添加高斯噪声,噪声的标准差是以 β t \beta_t βt确定的,均值是以固定值 β t \beta_t βt和当前时刻的数据 x t \bold{x_t} xt决定的,所以该过程并没有需要学习的参数,而且是一个马尔科夫链过程。
随着 t t t的不断增大,最终数据分布 x T x_T xT变成了一个各项独立的高斯分布
q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(\bold{x_t|x_{t-1}})=\mathcal{N}(\bold{x_t};\sqrt{1-\beta_t}\bold{x_{t-1},\beta_t\bold{I}}) q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)
q ( x 1 : T ∣ x o ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q(\bold{x_{1:T}|x_o})=\prod^{T}_{t=1}q(\bold{x_t|x_{t-1}}) q(x1:T∣xo)=t=1∏Tq(xt∣xt−1)
这充分体现了参数重整化的技巧。
任意时刻的 q ( x t ) q(\bold{x_t}) q(xt)推导也可以完全基于 x 0 \bold{x}_0 x0和 β t \beta_t βt计算得到闭式解,而不需要做迭代。(令 α t = 1 − β t \alpha_t=1-\beta_t αt=1−βt)
两个正态分布 X ∼ N ( μ 1 , σ 1 ) X\sim \mathcal{N}(\mu_1,\sigma_1) X∼N(μ1,σ1)和 Y ∼ N ( μ 2 , σ 2 ) Y\sim \mathcal{N}(\mu_2,\sigma_2) Y∼N(μ2,σ2)叠加后的分布 a X + b Y aX+bY aX+bY服从分布 N ( a μ 1 + b μ 2 , a 2 σ 1 2 + b 2 σ 2 2 ) \mathcal{N}(a\mu_1+b\mu_2,a^2\sigma_1^2+b^2\sigma_2^2) N(aμ1+bμ2,a2σ12+b2σ22)。
对于第 t t t步的分布 x t x_t xt等于上一步的分布 x t − 1 x_{t-1} xt−1加上高斯噪声 z t − 1 z_{t-1} zt−1,即
x t = α t x t − 1 + 1 − α t z t − 1 ; where z t − 1 , z t − 2 , . . . ∼ N ( 0 , I ) = α t α t − 1 x t − 2 + α t − α t α t − 1 z t − 2 + 1 − α t z t − 1 \begin{aligned} \bold{x}_t&=\sqrt{\alpha_t}\bold{x}_{t-1}+\sqrt{1-\alpha_t}\bold{z}_{t-1}\qquad ;\text{where} \ \bold{z_{t-1}},\bold{z_{t-2}},...\sim \mathcal{N}(\bold{0},\bold{I})\\ &=\sqrt{\alpha_t\alpha_{t-1}}\bold{x}_{t-2}+{\color{red} \sqrt{\alpha_t-\alpha_t\alpha_{t-1}}\bold{z}_{t-2}+\sqrt{1-\alpha_t}\bold{z_{t-1}}} \end{aligned} xt=αtxt−1+1−αtzt−1;where zt−1,zt−2,...∼N(0,I)=αtαt−1xt−2+αt−αtαt−1zt−2+1−αtzt−1
这里借助参数重整化的技巧,将红色部分的两个高斯分布合并为新的高斯分布,整理如下所示
x t = α t α t − 1 x t − 2 + 1 − α t α t − 1 z ˉ t − 2 \begin{aligned} \bold{x}_t&=\sqrt{\alpha_t\alpha_{t-1}}\bold{x}_{t-2}+{\color{red} \sqrt{1-\alpha_t\alpha_{t-1}}\bar{\bold{z}}_{t-2}} \end{aligned} xt=αtαt−1xt−2+1−αtαt−1zˉt−2
其中, z ˉ t − 2 ∼ N ( 0 , I ) \bar{\bold{z}}_{t-2}\sim \mathcal{N}(\bold{0},\bold{I}) zˉt−2∼N(0,I)
重复上面的步骤,最终可以得到 z t \bold{z}_t zt的闭式解
x t = α ˉ t x 0 + 1 − α ˉ t z ; where α ˉ t = ∏ i = 1 T α i \bold{x}_t=\sqrt{\bar{\alpha}_t}\bold{x}_0+\sqrt{1-\bar{\alpha}_{t}}\bold{z}\qquad ;\text{where}\ \bar{\alpha}_t=\prod_{i=1}^T\alpha_i xt=αˉtx0+1−αˉtz;where αˉt=i=1∏Tαi
此时,作者认为 x t ∼ N ( x t ; α ˉ t x 0 , 1 − α ˉ t I ) \bold{x}_t\sim \mathcal{N}(\bold{x}_t;\sqrt{\bar{\alpha}_t}\bold{x}_0,\sqrt{1-\bar{\alpha}_t}\bold{I}) xt∼N(xt;αˉtx0,1−αˉtI),(此处应该是认为 x 0 \bold{x}_0 x0是完全已知的,方差为零),最终当上述分布趋近于 N ( 0 , I ) \mathcal{N}(\bold{0},\bold{I}) N(0,I)的时候,可认为模型已经基本完成扩散过程。因此,作者给出了一种 β t \beta_t βt的设置经验, β 1 < β 2 < ⋅ ⋅ ⋅ < β T \beta_1<\beta_2<\cdot\cdot\cdot<\beta_T β1<β2<⋅⋅⋅<βT,即随着扩散深度的加深,逐步扩大 β \beta β。
逆过程是从高斯分布中恢复原始数据,当 β t \beta_t βt足够小时,逆过程的每一小步 p θ ( x t − 1 ∣ x t ) p_\theta(\bold{x}_{t-1}|\bold{x}_t) pθ(xt−1∣xt)也可视作高斯分布,即
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , ∑ θ ( x t , t ) ) p_\theta(\bold{x}_{t-1}|\bold{x}_t)=\mathcal{N}(\bold{x}_{t-1};\bold{\mu_\theta}(\bold{x}_t,t),\sum_\theta(\bold{x}_t,t)) pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),θ∑(xt,t))
逆扩散过程可以被总结为如下形式
p θ ( x 0 : T ) = p ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) p_\theta(\bold{x}_{0:T})=p(\bold{x}_T)\prod_{t=1}^Tp_\theta (\bold{x}_{t-1}|\bold{x}_t) pθ(x0:T)=p(xT)t=1∏Tpθ(xt−1∣xt)
此处通过使用网络估计参数 θ \theta θ以实现逆扩散过程。
根据条件概率的贝叶斯公式
q ( x t − 1 ∣ x t , x 0 ) q ( x t ∣ x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q(\bold{x}_{t-1}|\bold{x}_t,\bold{x}_0)q(\bold{x}_t|\bold{x}_0)=q(\bold{x}_{t}|\bold{x}_{t-1},\bold{x}_0)q(\bold{x}_{t-1}|\bold{x}_0) q(xt−1∣xt,x0)q(xt∣x0)=q(xt∣xt−1,x0)q(xt−1∣x0)
得到
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) ∝ exp ( − 1 2 ( ( x t − α t x t − 1 ) 2 1 − α t + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) ) = exp ( − 1 2 ( a x t − 1 2 + b x t − 1 + c ( x t , x 0 ) ) ) \begin{aligned} q(\bold{x}_{t-1}|\bold{x}_t,\bold{x}_0)&=q(\bold{x}_{t}|\bold{x}_{t-1},\bold{x}_0)\frac{q(\bold{x}_{t-1}|\bold{x}_0)}{q(\bold{x}_t|\bold{x}_0)}\\ &\propto \exp \left(-\frac{1}{2}\left(\frac{(\bold{x}_t-\sqrt{\alpha_t}\bold{x}_{t-1})^2}{1-\alpha_t}+\frac{(\bold{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\bold{x}_0)^2}{1-\bar{\alpha}_{t-1}}-\frac{(\bold{x}_{t}-\sqrt{\bar{\alpha}_{t}}\bold{x}_0)^2}{1-\bar{\alpha}_{t}}\right)\right)\\ &=\exp\left(-\frac{1}{2}\left({\color{blue} a}\bold{x}_{t-1}^2+{\color{red} b}\bold{x}_{t-1}+c(\bold{x}_t,\bold{x}_0)\right)\right) \end{aligned} q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0)∝exp(−21(1−αt(xt−αtxt−1)2+1−αˉt−1(xt−1−αˉt−1x0)2−1−αˉt(xt−αˉtx0)2))=exp(−21(axt−12+bxt−1+c(xt,x0)))
可见,上述分布的核心可以用一个二次函数来描述,那对应的中轴线应该是
μ ~ t ( x t , x 0 ) = − b 2 a = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ x 0 \begin{aligned} \bold{\tilde\mu_t}(\bold{x}_t,\bold{x}_0)&=-\frac{\color{red}{b}}{2\color{blue}{a}}\\ &=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\bold{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}}\bold{x}_0 \end{aligned} μ~t(xt,x0)=−2ab=1−αˉtαt(1−αˉt−1)xt+1−αˉαˉt−1βtx0
容易从扩散过程的表达式(式11)得到 x 0 \bold{x}_0 x0的表达式
x 0 = 1 α ˉ t ( x t − 1 − α ˉ t z ) \bold{x}_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(\bold{x}_t-\sqrt{1-\bar{\alpha}_{t}}\bold{z}\right) x0=αˉt1(xt−1−αˉtz)
带入得到
μ ~ t ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ 1 α ˉ t ( x t − 1 − α ˉ t z ) = 1 α t ( x t − β t 1 − α ˉ t z t ) \begin{aligned} \bold{\tilde\mu_t}(\bold{x}_t,\bold{x}_0) &=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\bold{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}}\frac{1}{\sqrt{\bar{\alpha}_t}}\left(\bold{x}_t-\sqrt{1-\bar{\alpha}_{t}}\bold{z}\right)\\ &=\color{green}{\frac{1}{\sqrt{\alpha}_t}\left(\bold{x}_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\bold{z}_t\right)} \end{aligned} μ~t(xt,x0)=1−αˉtαt(1−αˉt−1)xt+1−αˉαˉt−1βtαˉt1(xt−1−αˉtz)=αt1(xt−1−αˉtβtzt)
这就是 x t − 1 \bold{x}_{t-1} xt−1分布的均值表达式,即给定 x 0 \bold{x}_0 x0的条件下,后验条件高斯分布的均值计算只与 x t \bold{x}_{t} xt和 z t \bold{z}_t zt有关。
我们在待优化的目标数据分布的似然函数(负)上加一个非负的KL散度,构成负对数似然的上界,通过最小化上界,负对数似然自然取得最小。
− log p θ ( x 0 ) = − log p θ ( x 0 ) + D K L ( q ( x 1 : T ∣ x 0 ) ∣ ∣ p θ ( x 1 : T ∣ x 0 ) ) = − log p θ ( x 0 ) + E x 1 : T ∼ q ( x 1 : T ∣ x 0 ) [ log q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) / p θ ( x 0 ) ] = − log p θ ( x 0 ) + E x 1 : T ∼ q ( x 1 : T ∣ x 0 ) [ log q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) + log p θ ( x 0 ) ] = E x 1 : T ∼ q ( x 1 : T ∣ x 0 ) [ log q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] = L V L B \begin{aligned} -\log p_\theta(\bold{x}_0) &=-\log p_\theta(\bold{x}_0) + D_{KL}(q(\bold{x}_{1:T}|\bold{x}_0)||p_\theta(\bold{x}_{1:T}|\bold{x}_0))\\ &=-\log p_\theta(\bold{x}_0) +\mathbb{E}_{\bold{x}_{1:T}\sim q(\bold{x}_{1:T}|\bold{x}_0)}\left[\log \frac{q(\bold{x}_{1:T}|\bold{x}_0)}{p_\theta(\bold{x}_{0:T})/p_\theta (\bold{x}_0)}\right]\\ &=-\log p_\theta(\bold{x}_0) +\mathbb{E}_{\bold{x}_{1:T}\sim q(\bold{x}_{1:T}|\bold{x}_0)}\left[\log \frac{q(\bold{x}_{1:T}|\bold{x}_0)}{p_\theta(\bold{x}_{0:T})}+\log p_\theta (\bold{x}_0)\right]\\ &=\mathbb{E}_{\bold{x}_{1:T}\sim q(\bold{x}_{1:T}|\bold{x}_0)}\left[\log \frac{q(\bold{x}_{1:T}|\bold{x}_0)}{p_\theta(\bold{x}_{0:T})}\right]\qquad \color{blue}{=L_{VLB}} \end{aligned} −logpθ(x0)=−logpθ(x0)+DKL(q(x1:T∣x0)∣∣pθ(x1:T∣x0))=−logpθ(x0)+Ex1:T∼q(x1:T∣x0)[logpθ(x0:T)/pθ(x0)q(x1:T∣x0)]=−logpθ(x0)+Ex1:T∼q(x1:T∣x0)[logpθ(x0:T)q(x1:T∣x0)+logpθ(x0)]=Ex1:T∼q(x1:T∣x0)[logpθ(x0:T)q(x1:T∣x0)]=LVLB
我们也可以继续对 L V L B L_{VLB} LVLB进行展开,过程比较繁琐,建议查看论文,最终的形式如下
L V L B = E q [ D K L ( q ( x T ∣ x 0 ) ∣ ∣ p θ ( x T ) ) + ∑ t = 1 T D K L ( q ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p θ ( x t − 1 ∣ x t ) ) ] {\color{blue}{L_{VLB}}}=\mathbb{E}_q\left[D_{KL}\left(q(\bold{x}_T|\bold{x}_0)||p_\theta (\bold{x}_T)\right)+{\color{red} \sum_{t=1}^TD_{KL}(q(\bold{x}_{t-1}|\bold{x}_t,\bold{x}_0)||p_\theta(\bold{x}_{t-1}|\bold{x}_t))}\right] LVLB=Eq[DKL(q(xT∣x0)∣∣pθ(xT))+t=1∑TDKL(q(xt−1∣xt,x0)∣∣pθ(xt−1∣xt))]
其中第一项是不含待优化参数的,仅仅需要优化第二项即可。而且作者将 p θ ( x t − 1 ∣ x t ) p_\theta(\bold{x}_{t-1}|\bold{x}_t) pθ(xt−1∣xt)的方差设置为与 β \beta β有关的常数,可训练参数仅存在其均值中。
我们已经知道 q ( x t − 1 ∣ x t , x 0 ) q(\bold{x}_{t-1}|\bold{x}_t,\bold{x}_0) q(xt−1∣xt,x0)服从高斯分布,并给出了其均值的表达式,而且知道 p θ ( x t − 1 ∣ x t ) p_\theta(\bold{x}_{t-1}|\bold{x}_t) pθ(xt−1∣xt)也服从高斯分布,其方差设置为常数,仅需优化均值即可。使用文章开头给出的两个单一变量的高斯分布的KL散度表达式,两个分布的方差均为常数,最终的损失函数可以写作两个分布的均值的关系:
L t − 1 = E q [ 1 2 σ t 2 ∣ ∣ μ ~ t ( x t , x 0 ) − μ θ ( x t , t ) ∣ ∣ 2 ] + C {\color{red} L_{t-1}}=\mathbb{E}_q\left[\frac{1}{2\sigma_t^2}||\tilde{\bold\mu}_t(\bold{x}_t,\bold{x}_0)-\mu_\theta(\bold{x}_t,t)||^2\right]+C Lt−1=Eq[2σt21∣∣μ~t(xt,x0)−μθ(xt,t)∣∣2]+C
我们可以将已经得到的 μ t \mu_t μt的表达式,进行简化得到最终的损失函数:
L simple ( θ ) : = E t , x 0 , ϵ [ ∣ ∣ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∣ ∣ 2 ] L_{\text{simple}}(\theta):=\mathbb{E}_{t,\bold{x}_0,\bold\epsilon}\left[||\bold\epsilon-\bold\epsilon_\theta(\sqrt{\bar{\alpha}_t}\bold{x}_0+\sqrt{1-\bar{\alpha}_t}\bold\epsilon,t)||^2\right] Lsimple(θ):=Et,x0,ϵ[∣∣ϵ−ϵθ(αˉtx0+1−αˉtϵ,t)∣∣2]
这里, ϵ θ \epsilon_\theta ϵθ就是可学习的网络,输入 x 0 \bold{x}_0 x0和高斯噪声 ϵ \epsilon ϵ以及时刻 t t t。
优化好网络 ϵ θ \epsilon_\theta ϵθ之后,可以从 x T x_T xT逐步获得 x 0 x_0 x0