本章开始笔者来陆续的介绍最近爆火的Diffusion Model的近期发展。
本篇的学习内容与图片均来自于对文章Diffusion Models: A Comprehensive Survey of Methods and Applications的学习。本篇内容仅代表笔者个人学习观点和笔记,本篇内容的欢迎感兴趣的人一起学习和讨论,也欢迎大家对文章中错误进行纠正和批评。
本篇文章适合刚刚接触Diffusion Model的读者进行初步阅览,避免走一些弯路。本篇属于笔者自己的笔记,因此笔者尽量写的通俗易懂。事实上,如果读者已经学习过变分自编码器(VAE)或生成对抗网络(GAN),那么将相对较为容易的理解 Diffusion Model的原理,因为Diffusion Model利用到了它们的一些思想,而有关VAE/GAN的详细原理介绍,读者可参考其他博客或资料进行学习,如不懂VAE/GAN的新手小白,对于阅读本文而言也是也没有关系的。
本篇文章需要掌握的基础预先了解内容是概率论和稍略简单的随机过程,贝叶斯概率统计等内容即可,本篇文章着重为扩散生成的技术分析,目标在于用最简单的想法来尽可能阐述清楚Diffusion-Model的技术需要一些简单数学公式的推导。
本篇在介绍原理的同时也补充了关于DDPM原文的一些内容,虽然公式看起来会比较复杂,笔者建议读者进行自行推导,看似复杂其实并不是很困难。
在正式学习一个模型前,我们首先要理解该模型的任务是做什么的,首先它作为一个生成模型,是如何生成样本的:Diffision Model通过对样本进行逐步添加Noise破坏来原有样本数据(数据加噪),然后到达一定程度后,再一步一步的通过一个辅助Score-function的将数据一步一步反推生成(降噪),这一过程很像编码器中的Encoder(加噪)和Decoder(降噪)过程,为了方便理解,原文作者还给予了一张图示帮助理解Diffision Model的大致过程(一只狗通过逐步加噪后,根据Score-Function进行逐步降噪,最终生成新的数据)
DDPMS的加噪和去噪是通过两个Markov链完成的。其中加噪链用于将原始数据转换为容易处理的Noise数据,这个分布一般会被取为Normal(正态的),去噪链将采样后的Noise数据转换为新的生成数据。
假设原始数据 x 0 x_0 x0是从某一分布 x 0 ~ q ( x 0 ) x_0~q(x_0) x0~q(x0)中采样得到,若我们已经拥有了一条DDPM forward Markov Chain q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}) q(xt∣xt−1)。那么显然可以通过该Markov链生成一条采样大小为 T T T的采样序列,假设为 ( x 0 , x 1 , x 2 , x 3 ⋅ ⋅ ⋅ x T ) (x_0,x_1,x_2,x_3···x_T) (x0,x1,x2,x3⋅⋅⋅xT)。根据Markov链的性质与联合分布概率的性质显然会有该采样的联合分布为
q ( x 0 , x 1 , x 2 , x 3 ⋅ ⋅ ⋅ x T ) = ∏ i = 0 T q ( x t ∣ x t − 1 ) q(x_0,x_1,x_2,x_3···x_T)=\prod \limits_{i=0}^T q(x_t|x_{t-1}) q(x0,x1,x2,x3⋅⋅⋅xT)=i=0∏Tq(xt∣xt−1)
而这里的关键点在于, q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}) q(xt∣xt−1)是未知的,即状态转移概率未知,那么在加噪过程中为了让最终加得的噪声能够容易处理,一般我们会选择Normal分布。也即,通常来说这个状态转移概率(加噪链)由人为获取:
q ( x t ∣ x t − 1 ) = N ( ( 1 − β t ) x t − 1 , β t I ) ⟺ x t = ( 1 − β t ) x t − 1 + β t N ( 0 , 1 ) q(x_t|x_{t-1})=N((\sqrt{1-\beta_t})x_{t-1},\beta_tI) \iff x_t=(\sqrt{1-\beta_t})x_{t-1}+\sqrt{\beta_t}N(0,1) q(xt∣xt−1)=N((1−βt)xt−1,βtI)⟺xt=(1−βt)xt−1+βtN(0,1)
β ∈ ( 0 , 1 ) \beta \in (0,1) β∈(0,1)为超参数,训练前给定。这个过程比较容易理解为:每一步加噪的时候,其中一部分来源于原始数据,另一部分来源于一个正态分布的噪声。
并且它还有另一个好处(重点),令 a t = 1 − β t a_t=1-\beta_t at=1−βt,有如下结论:
q ( x k ∣ x 0 ) = N ( ( ∏ i = 0 k a i ) x 0 , ( 1 − ∏ i = 0 k a i ) I ) q(x_k|x_0)=N((\prod \limits_{i=0}^k \sqrt{a_i})x_0,(1-\prod \limits_{i=0}^k a_i)I) q(xk∣x0)=N((i=0∏kai)x0,(1−i=0∏kai)I)
原文并没给出该公式的具体推导办法,这里我给出一种非常证明办法,列举如下:
证明:
现在假设 T = 2 T=2 T=2,很容易可知道如下结果:
x 0 = x 0 x_0=x_0 x0=x0
x 1 = ( 1 − β 1 ) x 0 + β 1 z 0 x_1=(\sqrt{1-\beta_1})x_0+\sqrt{\beta_1}z_0 x1=(1−β1)x0+β1z0
x 2 = ( 1 − β 2 ) x 1 + β 2 z 1 = ( 1 − β 2 ) ( ( 1 − β 1 ) x 0 + β 1 z 0 ) + β 2 z 1 x_2=(\sqrt{1-\beta_2})x_1+\sqrt{\beta_2}z_1=(\sqrt{1-\beta_2})((\sqrt{1-\beta_1})x_0+\sqrt{\beta_1}z_0)+\sqrt{\beta_2}z_1 x2=(1−β2)x1+β2z1=(1−β2)((1−β1)x0+β1z0)+β2z1
x 2 x_2 x2改写,这即:
x 2 = ( 1 − β 2 ) ( 1 − β 1 ) x 0 + ( 1 − β 2 ) β 1 z 0 + β 2 z 1 x_2=(\sqrt{1-\beta_2})(\sqrt{1-\beta_1})x_0+(\sqrt{1-\beta_2})\sqrt{\beta_1}z_0+\sqrt{\beta_2}z_1 x2=(1−β2)(1−β1)x0+(1−β2)β1z0+β2z1
那么显然我们可以知道 x 2 x_2 x2的分布,这里需要用到两次加的随机正态噪声 z 0 , z 1 z_0,z_1 z0,z1是独立的,基于正态分布的性质:
q ( x 2 ∣ x 0 ) = N ( ( 1 − β 1 ) ( 1 − β 2 ) x 0 , ( β 1 + β 2 − β 1 β 2 ) I ) q(x_2|x_0)=N((\sqrt{1-\beta_1})(\sqrt{1-\beta_2})x_0,(\beta_1+\beta_2-\beta_1\beta_2)I) q(x2∣x0)=N((1−β1)(1−β2)x0,(β1+β2−β1β2)I)
注意到,该式可重写为
q ( x 2 ∣ x 0 ) = N ( ( 1 − β 1 ) ( 1 − β 2 ) x 0 , ( 1 − ( 1 − β 1 ) ( 1 − β 2 ) ) I ) q(x_2|x_0)=N((\sqrt{1-\beta_1})(\sqrt{1-\beta_2})x_0,(1-(1-\beta_1)(1-\beta_2))I) q(x2∣x0)=N((1−β1)(1−β2)x0,(1−(1−β1)(1−β2))I)
根据数学归纳法同理可证,令 a t = 1 − β t a_t=1-\beta_t at=1−βt:会有
q ( x k ∣ x 0 ) = N ( ( ∏ i = 0 k a i ) x 0 , ( 1 − ∏ i = 0 k a i ) I ) q(x_k|x_0)=N((\prod \limits_{i=0}^k \sqrt{a_i})x_0,(1-\prod \limits_{i=0}^k a_i)I) q(xk∣x0)=N((i=0∏kai)x0,(1−i=0∏kai)I)
证毕。
该公式给予了一个好处,我们无须真正的进行一步一步的采样,当在训练前设定好所有的超参数 β k ( k = 1 ⋅ ⋅ ⋅ T ) \beta_k(k=1···T) βk(k=1⋅⋅⋅T)根据该公式可直接进行每个步骤 k k k的采样,其中步骤 k k k时刻的采样数据 x k x_k xk服从如下分布:
N ( ( ∏ i = 0 k a i ) x 0 , ( 1 − ∏ i = 0 k a i ) I ) N((\prod \limits_{i=0}^k \sqrt{a_i})x_0,(1-\prod \limits_{i=0}^k {a_i})I) N((i=0∏kai)x0,(1−i=0∏kai)I)
注意到, ( ∏ i = 0 k a k ) (\prod \limits_{i=0}^k \sqrt{a_k}) (i=0∏kak)此项随着采样步骤的增多这个值会越来越小,这是因为都是小于1的数进行不断连乘,当加噪步骤足够多,趋向于无穷大时,此时该分布趋向于 N ( 0 , 1 ) N(0,1) N(0,1)。即当 k → ∞ k \rightarrow ∞ k→∞,这很容易理解,这也即该采样样本的边缘分布。
lim k → ∞ ( N ( ( ∏ i = 0 k a k ) x 0 , ( 1 − ∏ i = 0 k a k ) I ) ) = N ( 0 , 1 ) \lim_{k \rightarrow ∞}(N((\prod \limits_{i=0}^k \sqrt{a_k})x_0,(1-\prod \limits_{i=0}^k {a_k})I))=N(0,1) k→∞lim(N((i=0∏kak)x0,(1−i=0∏kak)I))=N(0,1)
lim k → ∞ q ( x k ) = ∫ q ( x k ∣ x 0 ) q ( x 0 ) d ( x 0 ) = N ( 0 , 1 ) \lim_{k \rightarrow ∞}q(x_k)=\int q(x_k|x_0)q(x_0)d(x_0)=N(0,1) k→∞limq(xk)=∫q(xk∣x0)q(x0)d(x0)=N(0,1)
当进行足够多的加噪步骤,设已经加了 T T T步噪后,此时根据我们2.1.1的讨论,该噪声已经近似分布为 N ( 0 , 1 ) N(0,1) N(0,1)。此时需要进行去噪——生成处理,而这一步骤需要用到所谓的去噪链来进行。
即,我们已经拥有了一个先验分布: x T ~ N ( 0 , 1 ) x_T~N(0,1) xT~N(0,1)。即我们已知了最终噪声分布,若我们知道一个去噪链 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt−1∣xt),那么就可以根据该去噪链“一条一条”的进行去噪处理了。注意到,我们假设 T T T足够大时候,有这样的等式存在,这是显然的:
p ( x T ) = q ( x T ) ~ N ( 0 , 1 ) p(x_T)=q(x_T)~N(0,1) p(xT)=q(xT)~N(0,1)
即我们通过加噪得到了最终的噪声分布,作为降噪前的先验分布,进行降噪处理。
该操作我们使用神经网络来进行,即假设去噪链为一个网络 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt−1∣xt),网络参数为 θ \theta θ,其中去噪网络满足如下分布
x t − 1 ~ N ( μ θ ( x t , t ) , Σ θ ( x t , t ) ) x_{t-1}~N(\mu_ \theta(x_t,t),\Sigma_\theta(x_t,t)) xt−1~N(μθ(xt,t),Σθ(xt,t))显然的,这需要两个网络 μ θ ( x t , t ) \mu_\theta(x_t,t) μθ(xt,t)与 Σ θ ( x t , t ) \Sigma_\theta(x_t,t) Σθ(xt,t),它们接受 t 步骤 t步骤 t步骤的输入 x t x_t xt与第 t t t步骤的位置信息,反馈一个正态分布,而 x t − 1 x_{t-1} xt−1从该分布中进行采样。
下面我用一个非常简单易懂的例子来总结一下DDPM的整体框架:
[ x 0 → x 1 → x 2 → ⋅ ⋅ ⋅ → x T − 1 ] → x T → [ x T − 1 ^ → ⋅ ⋅ ⋅ → x 2 ^ → x 1 ^ → x 0 ^ ] [x_0 \rightarrow x_1 \rightarrow x_2 \rightarrow···\rightarrow x_{T-1}] \rightarrow x_T \rightarrow [\hat{x_{T-1}} \rightarrow ···\rightarrow \hat{x_{2}}\rightarrow \hat{x_{1}}\rightarrow \hat{x_{0}}] [x0→x1→x2→⋅⋅⋅→xT−1]→xT→[xT−1^→⋅⋅⋅→x2^→x1^→x0^]
其中 x 0 , x 1 , x 2 , x 3 , ⋅ ⋅ ⋅ x T − 1 , x T x_0,x_1,x_2,x_3,···x_{T-1},x_T x0,x1,x2,x3,⋅⋅⋅xT−1,xT通过forward Markov Chain 采样获得(2.1.1), x 0 ^ , x 1 ^ , x 2 ^ , x 3 ^ , ⋅ ⋅ ⋅ x T − 1 ^ , x T ^ \hat{x_0},\hat{x_1},\hat{x_2},\hat{x_3},···\hat{x_{T-1}},\hat{x_T} x0^,x1^,x2^,x3^,⋅⋅⋅xT−1^,xT^通过reverse Markov Chain采样获得。注意到 x T ^ = x T \hat{x_T}=x_T xT^=xT,这一点我在上面已经做出过解释了。
DDPM的目的是去尽可能还原原来的噪声分布,注意,这里提到的是“分布还原”而不是数据还原,DDPM的目的是去学习一个这样加噪过程的逆过程。因此,目标很显然了,我们想要让 ( x 0 , x 1 , x 2 , x 3 , ⋅ ⋅ ⋅ x T − 1 , x T ) (x_0,x_1,x_2,x_3,···x_{T-1},x_T) (x0,x1,x2,x3,⋅⋅⋅xT−1,xT)与 ( x 0 ^ , x 1 ^ , x 2 ^ , x 3 ^ , ⋅ ⋅ ⋅ x T − 1 , x T ^ ^ ) (\hat{x_0},\hat{x_1},\hat{x_2},\hat{x_3},···\hat{x_{T-1},\hat{x_T}}) (x0^,x1^,x2^,x3^,⋅⋅⋅xT−1,xT^^)分布尽可能相似:
q ∗ = q ( x 0 , x 1 , x 2 , x 3 , ⋅ ⋅ ⋅ x T − 1 , x T ) = q ( x 0 ) ∏ i = 1 T q ( x i ∣ x i − 1 ) q^{*}=q(x_0,x_1,x_2,x_3,···x_{T-1},x_T)=q(x_0)\prod \limits_{i=1}^Tq(x_i|x_{i-1}) q∗=q(x0,x1,x2,x3,⋅⋅⋅xT−1,xT)=q(x0)i=1∏Tq(xi∣xi−1)
p θ ∗ = p θ ( x 0 ^ , x 1 ^ , x 2 ^ , x 3 ^ , ⋅ ⋅ ⋅ x T − 1 , x T ^ ^ ) = p ( x T ^ ) ∏ i = 1 T p θ ( x i − 1 ^ ∣ x i ^ ) p_{\theta}^{*}=p_{\theta}(\hat{x_0},\hat{x_1},\hat{x_2},\hat{x_3},···\hat{x_{T-1},\hat{x_T}})=p(\hat{x_T})\prod \limits_{i=1}^Tp_{\theta}(\hat{x_{i-1}}|\hat{x_{i}}) pθ∗=pθ(x0^,x1^,x2^,x3^,⋅⋅⋅xT−1,xT^^)=p(xT^)i=1∏Tpθ(xi−1^∣xi^)
而我们知道,衡量两个分布的相似度是使用相对熵(KL散度)来构建。为了使上述两个联合分布相近,需要最优化它们的KL散度使其达到最小值,这也即网络的Loss Function:
L o s s = K L ( [ q ∗ ∣ ∣ p θ ∗ ] ) = ∑ q ∗ l o g ( q ∗ p θ ∗ ) = ∑ q ∗ l o g ( q ∗ ) − ∑ q ∗ l o g ( p θ ∗ ) Loss=KL([q^*||p_{\theta}^*])=\sum q^*log(\frac{q^*}{p_{\theta}^*})=\sum q^*log(q^*)-\sum q^*log(p_{\theta}^*) Loss=KL([q∗∣∣pθ∗])=∑q∗log(pθ∗q∗)=∑q∗log(q∗)−∑q∗log(pθ∗)
事实上,由于我们是知道 q ∗ q* q∗的具体表达式的,因为加噪过程是人为给定好的,令 ∑ q ∗ l o g ( q ∗ ) = K \sum q^*log(q^*)=K ∑q∗log(q∗)=K,我们会有:
L o s s = − ∑ q ∗ l o g ( p θ ∗ ) + K = − E q ∗ [ l o g ( p θ ∗ ) ] + K Loss=-\sum q^*log(p_{\theta}^*)+K=-E_{q^*}[log(p_{\theta}^*)]+K Loss=−∑q∗log(pθ∗)+K=−Eq∗[log(pθ∗)]+K
这显然是没办法直接进行计算的,需要进行进一步的整理,我们将 p θ ∗ p_{\theta}^* pθ∗的具体形式带入进来,并且根据 ∑ q ∗ l o g ( ∑ i = 1 T q ( x i ∣ x i − 1 ) ) = K − q ∗ q ( x 0 ) = M \sum q^*log(\sum \limits_{i=1}^T q(x_i|x_{i-1}))=K-q^*q(x_0)=M ∑q∗log(i=1∑Tq(xi∣xi−1))=K−q∗q(x0)=M。会发现:
L o s s = − E q ∗ [ l o g ( p ( x T ^ ) ) + ∑ i = 1 T l o g p θ ( x i − 1 ^ ∣ x i ^ ) q ( x i ∣ x i − 1 ) ] + K + M ≥ − E q ∗ ( p θ ( x 0 ) ) Loss=-E_{q^*}[log(p(\hat{x_T}))+\sum \limits_{i=1}^Tlog\frac{p_{\theta}(\hat{x_{i-1}}|\hat{x_{i}})}{q(x_i|x_{i-1})}]+K+M \ge -E_{q^*}(p_\theta(x_0)) Loss=−Eq∗[log(p(xT^))+i=1∑Tlogq(xi∣xi−1)pθ(xi−1^∣xi^)]+K+M≥−Eq∗(pθ(x0))
(该不等式的证明下面一小部分分析内容不感兴趣的读者可以不看)
原文并未给出下界的证明,但是给予了提示,这里我补充一下。我们都知道, K L KL KL散度是恒大于零的,证明办法采用了Jenson不等式,这里采用相同的办法。
( J e n s o n − i n e q u a l i t y ) (Jenson-inequality) (Jenson−inequality): f ( x ) f(x) f(x)为凸(凹)的, λ i > 0 \lambda_i>0 λi>0且 ∑ i = 1 N λ i = 1 \sum_{i=1}^N\lambda_i=1 ∑i=1Nλi=1。则有
f ( ∑ i = 1 N λ i x i ) ≤ ( ≥ ) ∑ i = 1 N λ i f ( x i ) f(\sum_{i=1}^N\lambda_ix_i) \le(\ge) \sum_{i=1}^N\lambda_if(x_i) f(i=1∑Nλixi)≤(≥)i=1∑Nλif(xi)
我们利用凹性,因为log函数显然是凹的,使用Jenson不等式:
L o s s ≥ − E q ∗ ( ∑ i = 1 T l o g p θ ( x i − 1 ^ ∣ x i ^ ) q ( x i ∣ x i − 1 ) ) ≥ − l o g ( ∑ k = 1 M ∑ i = 1 T q ( x 0 k ) ⋅ ⋅ ⋅ q ( x T k ∣ x T − 1 k ) p θ ( x i − 1 ^ k ∣ x i ^ k ) q ( x i k ∣ x i − 1 k ) ) Loss \ge -E_{q^*}(\sum \limits_{i=1}^Tlog\frac{p_{\theta}(\hat{x_{i-1}}|\hat{x_{i}})}{q(x_i|x_{i-1})})\ge -log(\sum_{k=1}^M\sum_{i=1}^Tq(x_0^k)···q(x_T^k|x_{T-1}^k)\frac{p_{\theta}(\hat{x_{i-1}}^k|\hat{x_{i}}^k)}{q(x_i^k|x_{i-1}^k)}) Loss≥−Eq∗(i=1∑Tlogq(xi∣xi−1)pθ(xi−1^∣xi^))≥−log(k=1∑Mi=1∑Tq(x0k)⋅⋅⋅q(xTk∣xT−1k)q(xik∣xi−1k)pθ(xi−1^k∣xi^k))
− l o g ( ∑ k = 1 M ∑ i = 1 T q ( x 0 k ) ⋅ ⋅ ⋅ q ( x T k ∣ x T − 1 k ) p θ ( x i − 1 ^ k ∣ x i ^ k ) q ( x i k ∣ x i − 1 k ) ) ≥ − l o g ( ∑ k = 1 M ∑ i = 1 T q ∗ ( k ) p θ ( x i − 1 ^ k ∣ x i ^ k ) ) -log(\sum_{k=1}^M\sum_{i=1}^Tq(x_0^k)···q(x_T^k|x_{T-1}^k)\frac{p_{\theta}(\hat{x_{i-1}}^k|\hat{x_{i}}^k)}{q(x_i^k|x_{i-1}^k)})\ge -log(\sum_{k=1}^M\sum_{i=1}^Tq^{*(k)}p_{\theta}(\hat{x_{i-1}}^k|\hat{x_{i}}^k)) −log(k=1∑Mi=1∑Tq(x0k)⋅⋅⋅q(xTk∣xT−1k)q(xik∣xi−1k)pθ(xi−1^k∣xi^k))≥−log(k=1∑Mi=1∑Tq∗(k)pθ(xi−1^k∣xi^k))
再使用一次Jenson不等式:
− l o g ( ∑ k = 1 M ∑ i = 1 T q ∗ ( k ) p θ ( x i − 1 ^ k ∣ x i ^ k ) ) ≥ − ∑ k = 1 M q ∗ ( k ) ∑ i = 1 T l o g ( p θ ( x i − 1 ^ k ∣ x i ^ k ) ) -log(\sum_{k=1}^M\sum_{i=1}^Tq^{*(k)}p_{\theta}(\hat{x_{i-1}}^k|\hat{x_{i}}^k)) \ge -\sum_{k=1}^Mq^{*(k)}\sum_{i=1}^Tlog(p_{\theta}(\hat{x_{i-1}}^k|\hat{x_{i}}^k)) −log(k=1∑Mi=1∑Tq∗(k)pθ(xi−1^k∣xi^k))≥−k=1∑Mq∗(k)i=1∑Tlog(pθ(xi−1^k∣xi^k))
这即得到了原文中给的不等式:
L o s s ≥ − ∑ k = 1 M q ∗ ( k ) ∑ i = 1 T l o g ( p θ ( x i − 1 ^ k ∣ x i ^ k ) ) = − E q ∗ ( − l o g ( p θ ( x 0 ) ) ) Loss \ge -\sum_{k=1}^Mq^{*(k)}\sum_{i=1}^Tlog(p_{\theta}(\hat{x_{i-1}}^k|\hat{x_{i}}^k))=-E_{q^*}(-log(p_\theta(x_0))) Loss≥−k=1∑Mq∗(k)i=1∑Tlog(pθ(xi−1^k∣xi^k))=−Eq∗(−log(pθ(x0)))
观察到,我们上面已经讨论过了,上面有这个等式的存在
L o s s = − E q ∗ [ l o g ( p ( x T ^ ) ) + ∑ i = 1 T l o g p θ ( x i − 1 ^ ∣ x i ^ ) q ( x i ∣ x i − 1 ) ] + K + M Loss=-E_{q^*}[log(p(\hat{x_T}))+\sum \limits_{i=1}^Tlog\frac{p_{\theta}(\hat{x_{i-1}}|\hat{x_{i}})}{q(x_i|x_{i-1})}]+K+M Loss=−Eq∗[log(p(xT^))+i=1∑Tlogq(xi∣xi−1)pθ(xi−1^∣xi^)]+K+M
其中, K , M K,M K,M为常数,也即不与网络参数 θ \theta θ有关的恒常值。那么当然它不影响Loss的优化,那么我们的目标是寻找使得下式最大的参数 θ \theta θ以获得:
θ = a r g m a x θ ( E q ∗ [ l o g ( p ( x T ^ ) ) + ∑ i = 1 T l o g p θ ( x i − 1 ^ ∣ x i ^ ) q ( x i ∣ x i − 1 ) ] ) \theta = argmax_{\theta}(E_{q^*}[log(p(\hat{x_T}))+\sum \limits_{i=1}^Tlog\frac{p_{\theta}(\hat{x_{i-1}}|\hat{x_{i}})}{q(x_i|x_{i-1})}]) θ=argmaxθ(Eq∗[log(p(xT^))+i=1∑Tlogq(xi∣xi−1)pθ(xi−1^∣xi^)])
L o s s = − ( E q ∗ [ l o g ( p ( x T ^ ) ) + ∑ i = 1 T l o g p θ ( x i − 1 ^ ∣ x i ^ ) q ( x i ∣ x i − 1 ) ] ) —— ( 1 ∗ ) Loss=-(E_{q^*}[log(p(\hat{x_T}))+\sum \limits_{i=1}^Tlog\frac{p_{\theta}(\hat{x_{i-1}}|\hat{x_{i}})}{q(x_i|x_{i-1})}])——(1^*) Loss=−(Eq∗[log(p(xT^))+i=1∑Tlogq(xi∣xi−1)pθ(xi−1^∣xi^)])——(1∗)
其中,该过程通过Monte Carlo模拟来进行即:正向采样——反向采样——计算Loss。
Monte Carlo模拟虽然可以估计该Loss大小,虽然能保证是无偏的(一定程度上),但是它有个巨大的缺陷和致命弱点:它具有极高方差性,这显然对于模型是非常不利的,那么我们要平稳的训练一个模型,需要改写上述的 ( 1 ∗ ) (1^*) (1∗)来获得一个低方差的公式,因此DDPM核心在于2.1.3.2的低方差更新方式。
根据 ( 1 ∗ ) (1^*) (1∗)我们已经知道了它具有很强的随机方差性,下面若能解析性的改写 ( 1 ∗ ) (1^*) (1∗),目的为降低计算和采样方差。
L o s s = − ( E q ∗ [ l o g ( p ( x T ^ ) ) + ∑ i = 1 T l o g p θ ( x i − 1 ^ ∣ x i ^ ) q ( x i ∣ x i − 1 ) ] ) —— ( 1 ∗ ) Loss=-(E_{q^*}[log(p(\hat{x_T}))+\sum \limits_{i=1}^Tlog\frac{p_{\theta}(\hat{x_{i-1}}|\hat{x_{i}})}{q(x_i|x_{i-1})}])——(1^*) Loss=−(Eq∗[log(p(xT^))+i=1∑Tlogq(xi∣xi−1)pθ(xi−1^∣xi^)])——(1∗)
首先它可被等价为如下:
L o s s = − ( E q ∗ [ l o g ( p ( x T ^ ) ) + ∑ i = 2 T l o g p θ ( x i − 1 ^ ∣ x i ^ ) q ( x i ∣ x i − 1 ) + l o g ( p θ ( x 0 ^ ∣ x 1 ^ ) q ( x 1 ∣ x 0 ) ) ] ) ( 2 ∗ ) Loss=-(E_{q^*}[log(p(\hat{x_T}))+\sum \limits_{i=2}^Tlog\frac{p_{\theta}(\hat{x_{i-1}}|\hat{x_{i}})}{q(x_i|x_{i-1})}+log(\frac{p_{\theta}(\hat{x_{0}}|\hat{x_{1}})}{q(x_1|x_{0})})])(2^*) Loss=−(Eq∗[log(p(xT^))+i=2∑Tlogq(xi∣xi−1)pθ(xi−1^∣xi^)+log(q(x1∣x0)pθ(x0^∣x1^))])(2∗)
根据贝叶斯公式我们将 q ( x i ∣ x i − 1 ) q(x_i|x_{i-1}) q(xi∣xi−1)视为后验分布,则会有
q ( x i ∣ x i − 1 ) = q ( x i − 1 ∣ x i , x 0 ) q ( x i ) q ( x i − 1 ) = q ( x i − 1 ∣ x i , x 0 ) q ( x i ∣ x 0 ) q ( x i − 1 ∣ x 0 ) q(x_i|x_{i-1})=\frac{q(x_{i-1}|x_i,x_0)q(x_i)}{q(x_{i-1})}=\frac{q(x_{i-1}|x_i,x_0)q(x_i|x_0)}{q(x_{i-1}|x_0)} q(xi∣xi−1)=q(xi−1)q(xi−1∣xi,x0)q(xi)=q(xi−1∣x0)q(xi−1∣xi,x0)q(xi∣x0)
将上述公式带入到 ( 2 ∗ ) (2^*) (2∗)中会得到
L o s s = − ( E q ∗ [ l o g ( p ( x T ^ ) ) + ∑ i = 2 T l o g p θ ( x i − 1 ^ ∣ x i ^ ) q ( x i − 1 ∣ x i , x 0 ) q ( x i − 1 ∣ x 0 ) q ( x i ∣ x 0 ) + l o g ( p θ ( x 0 ^ ∣ x 1 ^ ) q ( x 1 ∣ x 0 ) ) ] ) ( 3 ∗ ) Loss=-(E_{q^*}[log(p(\hat{x_T}))+\sum \limits_{i=2}^Tlog\frac{p_{\theta}(\hat{x_{i-1}}|\hat{x_{i}})}{q(x_{i-1}|x_i,x_0)}\frac{q(x_{i-1}|x_0)}{q(x_i|x_0)}+log(\frac{p_{\theta}(\hat{x_{0}}|\hat{x_{1}})}{q(x_1|x_{0})})])(3^*) Loss=−(Eq∗[log(p(xT^))+i=2∑Tlogq(xi−1∣xi,x0)pθ(xi−1^∣xi^)q(xi∣x0)q(xi−1∣x0)+log(q(x1∣x0)pθ(x0^∣x1^))])(3∗)
观察一下有一项,可以提出来即:
∑ i = 2 T l o g ( q ( x i − 1 ∣ x 0 ) q ( x i ∣ x 0 ) ) = l o g ( q ( x 1 ∣ x 0 ) ) − l o g ( q ( x T ∣ x 0 ) ) \sum \limits_{i=2}^Tlog(\frac{q(x_{i-1}|x_0)}{q(x_i|x_0)})=log(q(x_1|x_0))-log(q(x_T|x_0)) i=2∑Tlog(q(xi∣x0)q(xi−1∣x0))=log(q(x1∣x0))−log(q(xT∣x0))
( 3 ∗ ) (3^*) (3∗)可被改写为如下简单的式子,其中会有 x T ^ = x T \hat{x_T}=x_T xT^=xT,这在上面已经声明过了:
L o s s = − ( E q ∗ [ l o g p ( x T ) q ( x T ∣ x 0 ) + ∑ i = 2 T l o g p θ ( x i − 1 ^ ∣ x i ^ ) q ( x i − 1 ∣ x i , x 0 ) + l o g ( p θ ( x 0 ^ ∣ x 1 ^ ) ] ) ( 4 ∗ ) Loss=-(E_{q^*}[log\frac{p({x_T})}{q(x_T|x_0)}+\sum \limits_{i=2}^Tlog\frac{p_{\theta}(\hat{x_{i-1}}|\hat{x_{i}})}{q(x_{i-1}|x_i,x_0)}+log(p_{\theta}(\hat{x_{0}}|\hat{x_{1}})])(4^*) Loss=−(Eq∗[logq(xT∣x0)p(xT)+i=2∑Tlogq(xi−1∣xi,x0)pθ(xi−1^∣xi^)+log(pθ(x0^∣x1^)])(4∗)
事实上,结合 K L KL KL散度定义我们知道这即等价于下式:
E q ∗ [ K L [ q ( x T ∣ x 0 ) ∣ ∣ p ( x T ) ] + ∑ i = 2 T K L [ q ( x i − 1 ∣ x i , x 0 ) ∣ ∣ p θ ( x i − 1 ^ ∣ x i ^ ) ] − l o g ( p θ ( x 0 ^ ∣ x 1 ^ ) ] E_{q^*}[KL[q(x_T|x_0)||p(x_T)]+\sum \limits_{i=2}^TKL[q(x_{i-1}|x_i,x_0)||p_{\theta}(\hat{x_{i-1}}|\hat{x_{i}})]-log(p_{\theta}(\hat{x_{0}}|\hat{x_{1}})] Eq∗[KL[q(xT∣x0)∣∣p(xT)]+i=2∑TKL[q(xi−1∣xi,x0)∣∣pθ(xi−1^∣xi^)]−log(pθ(x0^∣x1^)]
令 K L [ q ( x i − 1 ∣ x i , x 0 ) ∣ ∣ p θ ( x i − 1 ^ ∣ x i ^ ) ] = L i − 1 KL[q(x_{i-1}|x_i,x_0)||p_{\theta}(\hat{x_{i-1}}|\hat{x_{i}})]=L_{i-1} KL[q(xi−1∣xi,x0)∣∣pθ(xi−1^∣xi^)]=Li−1
L o s s = E q ∗ [ L T + ∑ i = 2 T L i − 1 − L 0 ] = E q ∗ [ ∑ i = 1 T L i − L 0 ] ( 5 ∗ ) Loss=E_{q^*}[L_T+\sum \limits_{i=2}^TL_{i-1}-L_0]=E_{q^*}[\sum \limits_{i=1}^TL_{i}-L_0](5^*) Loss=Eq∗[LT+i=2∑TLi−1−L0]=Eq∗[i=1∑TLi−L0](5∗)
但是这缺少一个问题, q ( x i − 1 ∣ x i , x 0 ) q(x_{i-1}|x_i,x_0) q(xi−1∣xi,x0)的分布我们目前还不知道,如何计算?
下面给予 q ( x i − 1 ∣ x i , x 0 ) q(x_{i-1}|x_i,x_0) q(xi−1∣xi,x0)分布计算过程:
根据上述已经提到过的贝叶斯公式,已经知道了如下结论
q ( x i ∣ x i − 1 ) = q ( x i − 1 ∣ x i , x 0 ) q ( x i ∣ x 0 ) q ( x i − 1 ∣ x 0 ) → q ( x i − 1 ∣ x i , x 0 ) = q ( x i ∣ x i − 1 ) q ( x i − 1 ∣ x 0 ) q ( x i ∣ x 0 ) q(x_i|x_{i-1})=\frac{q(x_{i-1}|x_i,x_0)q(x_i|x_0)}{q(x_{i-1}|x_0)} \rightarrow q(x_{i-1}|x_i,x_0)=\frac{q(x_i|x_{i-1})q(x_{i-1}|x_0)}{q(x_i|x_0)} q(xi∣xi−1)=q(xi−1∣x0)q(xi−1∣xi,x0)q(xi∣x0)→q(xi−1∣xi,x0)=q(xi∣x0)q(xi∣xi−1)q(xi−1∣x0)
事实上,这三个分布我们都是知道的,即:
q ( x i ∣ x i − 1 ) ~ N ( a i x i − 1 , ( 1 − a i ) I ) q(x_i|x_{i-1})~N(\sqrt{a_i}x_{i-1},(1-a_i)I) q(xi∣xi−1)~N(aixi−1,(1−ai)I)
q ( x i − 1 ∣ x 0 ) ~ N ( ( ∏ k = 0 i − 1 a k ) x 0 , ( 1 − ∏ k = 0 i − 1 a k ) I ) q(x_{i-1}|x_0)~N((\prod \limits_{k=0}^{i-1} \sqrt{a_k})x_0,(1-\prod \limits_{k=0}^{i-1} a_k)I) q(xi−1∣x0)~N((k=0∏i−1ak)x0,(1−k=0∏i−1ak)I)
q ( x i ∣ x 0 ) ~ N ( ( ∏ k = 0 i a k ) x 0 , ( 1 − ∏ k = 0 i a k ) I ) q(x_{i}|x_0)~N((\prod \limits_{k=0}^{i} \sqrt{a_k})x_0,(1-\prod \limits_{k=0}^{i} {a_k})I) q(xi∣x0)~N((k=0∏iak)x0,(1−k=0∏iak)I)
根据贝叶斯统计理论知道一个非常重要的结论,先验分布 q ( x i ∣ x i − 1 ) q(x_i|x_{i-1}) q(xi∣xi−1)是正态分布,那我们所求的后验分布 q ( x i − 1 ∣ x i , x 0 ) q(x_{i-1}|x_i,x_0) q(xi−1∣xi,x0)一定也是正态分布:
我们忽略常数项,可知原分布正比于下面的公式即: q ( x i − 1 ∣ x i , x 0 ) ∝ q(x_{i-1}|x_i,x_0)\propto q(xi−1∣xi,x0)∝:
e x p ( − 1 2 [ ( x i − a i x i − 1 ) 2 1 − a i + ( x i − 1 − ∏ k = 0 i − 1 a k x 0 ) 2 1 − ∏ k = 0 i − 1 a k + ( x i − ∏ k = 0 i a k x 0 ) 2 1 − ∏ k = 0 i a k ] ) ( ∗ ∗ ) exp(-\frac{1}{2}[\frac{(x_i-\sqrt{a_i}x_{i-1})^2}{1-a_i}+\frac{(x_{i-1}-\sqrt{\prod \limits_{k=0}^{i-1} a_k}x_0)^2}{1-\prod \limits_{k=0}^{i-1} a_k}+\frac{(x_{i}-\sqrt{\prod \limits_{k=0}^{i} a_k}x_0)^2}{1-\prod \limits_{k=0}^{i} a_k}])(**) exp(−21[1−ai(xi−aixi−1)2+1−k=0∏i−1ak(xi−1−k=0∏i−1akx0)2+1−k=0∏iak(xi−k=0∏iakx0)2])(∗∗)
由于已经知道了后验分布一定为正态分布,那么即可直接获得其中与 x i − 1 x_{i-1} xi−1有关的部分
其他放成常数倍落在后面。
e x p ( − 1 2 [ ( a i 1 − a i + 1 1 − ∏ k = 0 i − 1 a k ) x i − 1 2 − ( 2 a i 1 − a i x i + 2 ∏ k = 0 i − 1 a k 1 − ∏ k = 0 i − 1 a k x 0 ) x i − 1 ] ) + C exp(-\frac{1}{2}[(\frac{a_i}{1-a_i}+\frac{1}{1-\prod \limits_{k=0}^{i-1} a_k})x_{i-1}^2-(\frac{2\sqrt{a_i}}{1-a_i}x_i+\frac{2\sqrt{\prod \limits_{k=0}^{i-1} a_k}}{1-\prod \limits_{k=0}^{i-1} a_k}x_0)x_{i-1}])+C exp(−21[(1−aiai+1−k=0∏i−1ak1)xi−12−(1−ai2aixi+1−k=0∏i−1ak2k=0∏i−1akx0)xi−1])+C
最后配方即可获得了后验分布的表达式
q ( x i − 1 ∣ x i , x 0 ) ~ N ( μ , σ 2 I ) q(x_{i-1}|x_i,x_0)~N(\mu,\sigma^2 I) q(xi−1∣xi,x0)~N(μ,σ2I)
其中:
μ i = ( a i 1 − a i x i + ∏ k = 0 i − 1 a k 1 − ∏ k = 0 i − 1 a k x 0 ) / ( a i 1 − a i + 1 1 − ∏ k = 0 i − 1 a k ) \mu_i=(\frac{\sqrt{a_i}}{1-a_i}x_i+\frac{\sqrt{\prod \limits_{k=0}^{i-1} a_k}}{1-\prod \limits_{k=0}^{i-1} a_k}x_0) /(\frac{a_i}{1-a_i}+\frac{1}{1-\prod \limits_{k=0}^{i-1} a_k}) μi=(1−aiaixi+1−k=0∏i−1akk=0∏i−1akx0)/(1−aiai+1−k=0∏i−1ak1)
σ i 2 = 1 / ( a i 1 − a i + 1 1 − ∏ k = 0 i − 1 a k ) \sigma_i^2=1/(\frac{a_i}{1-a_i}+\frac{1}{1-\prod \limits_{k=0}^{i-1} a_k}) σi2=1/(1−aiai+1−k=0∏i−1ak1)
这样 ( 5 ∗ ) (5^*) (5∗)中所有变量均有了它们的表达式,并且 K L ( ⋅ | ⋅ ) KL(·|·) KL(⋅|⋅)这里计算的都是两个正态分布之间的差异性。
DDPM由以下几点重要部分构成:
一、人为设定的加噪链 q q q( β \beta β足够小让最终噪声分布趋近正态分布)
二、网络设定的去噪链 p θ p_\theta pθ
三、KL低方差损失函数
现在一、二我们已经介绍完了,下面要对三进行具体的更新和采样了。
注意到 L T = K L [ q ( x T ∣ x 0 ) ∣ ∣ p ( x T ) ] L_T=KL[q(x_T|x_0)||p(x_T)] LT=KL[q(xT∣x0)∣∣p(xT)]是给定的,因为显然地,若满足上述三个条件,那么 p ( x T ) ~ N ( 0 , 1 ) p(x_T)~N(0,1) p(xT)~N(0,1),那么其实训练过程我们没有必要关注 L T L_T LT部分,因为不影响模型参数结果。只需要优化一下的新目标函数:
L o s s ∗ = E q ∗ [ ∑ i = 2 T L i − 1 − L 0 ] Loss^*=E_{q^*}[\sum \limits_{i=2}^TL_{i-1}-L_0] Loss∗=Eq∗[i=2∑TLi−1−L0]
那么如何确定优化目标?我们的目标是要去模拟加噪链反过来的这条去噪链,如果将加噪链视为一个先验分布,那么我们的目标即要训练一个后验分布,这样就可以反向采样!
而我们已经知道了后验分布满足了
q ( x i − 1 ∣ x i , x 0 ) ~ N ( μ , σ 2 I ) q(x_{i-1}|x_i,x_0)~N(\mu,\sigma^2 I) q(xi−1∣xi,x0)~N(μ,σ2I)
μ i = ( a i 1 − a i x i + ∏ k = 0 i − 1 a k 1 − ∏ k = 0 i − 1 a k x 0 ) / ( a i 1 − a i + 1 1 − ∏ k = 0 i − 1 a k ) \mu_i=(\frac{\sqrt{a_i}}{1-a_i}x_i+\frac{\sqrt{\prod \limits_{k=0}^{i-1} a_k}}{1-\prod \limits_{k=0}^{i-1} a_k}x_0) /(\frac{a_i}{1-a_i}+\frac{1}{1-\prod \limits_{k=0}^{i-1} a_k}) μi=(1−aiaixi+1−k=0∏i−1akk=0∏i−1akx0)/(1−aiai+1−k=0∏i−1ak1)
σ i 2 = 1 / ( a i 1 − a i + 1 1 − ∏ k = 0 i − 1 a k ) \sigma_i^2=1/(\frac{a_i}{1-a_i}+\frac{1}{1-\prod \limits_{k=0}^{i-1} a_k}) σi2=1/(1−aiai+1−k=0∏i−1ak1)
如果我能用一个网络 p θ ( μ θ ( x i , i ) , Σ θ ( x i , i ) ) p_\theta(\mu_\theta(x_i,i),\Sigma_\theta(x_i,i)) pθ(μθ(xi,i),Σθ(xi,i))预测即可,注意到我们要去预测的分布里面 σ \sigma σ是没有有关 x i x_i xi的参数的,是一个常数因此该网络参数可被重写为 p θ ( μ θ ( x i , i ) , σ i 2 I ) p_\theta(\mu_\theta(x_i,i),\sigma_i^2I) pθ(μθ(xi,i),σi2I)。只需要对 μ θ ( x i , i ) \mu_\theta(x_i,i) μθ(xi,i)进行预估即可。
通过上述讨论,我们将 L t − 1 L_{t-1} Lt−1写出来看
L t − 1 = K L [ q ( x t ∣ x t + 1 , x 0 ) ∣ ∣ p θ ( x t ^ ∣ x t + 1 ^ ) ] = K L [ N ( μ t , σ i 2 I ) ∣ ∣ N ( μ θ , σ i 2 I ) ] L_{t-1}=KL[q(x_{t}|x_{t+1},x_0)||p_{\theta}(\hat{x_{t}}|\hat{x_{t+1}})]=KL[N(\mu_t,\sigma_i^2I)||N(\mu_\theta,\sigma_i^2I)] Lt−1=KL[q(xt∣xt+1,x0)∣∣pθ(xt^∣xt+1^)]=KL[N(μt,σi2I)∣∣N(μθ,σi2I)]
进行简要的推导可知道:
K L [ N ( μ t , σ i 2 I ) ∣ ∣ N ( μ θ , σ i 2 I ) ] = ∫ x 1 2 π σ i e ( x − μ t ) 2 2 σ i 2 ( − 1 2 σ i 2 [ ( x − μ θ ) 2 − ( x − μ t ) 2 ] ) d x KL[N(\mu_t,\sigma_i^2I)||N(\mu_\theta,\sigma_i^2I)]=\int_x\frac{1}{\sqrt{2\pi}\sigma_i}e^{\frac{(x-\mu_t)^2}{2\sigma_i^2}}(-\frac{1}{2\sigma_i^2}[(x-\mu_\theta)^2-(x-\mu_t)^2])dx KL[N(μt,σi2I)∣∣N(μθ,σi2I)]=∫x2πσi1e2σi2(x−μt)2(−2σi21[(x−μθ)2−(x−μt)2])dx
其中
∫ x 1 2 π σ i e ( x − μ t ) 2 2 σ i 2 ( 1 2 σ i 2 [ ( x − μ t ) 2 ] ) d x = 1 2 σ i 2 E ( [ x − E ( x ) ] 2 ) = 1 2 σ i 2 D ( x ) = 1 2 \int_x\frac{1}{\sqrt{2\pi}\sigma_i}e^{\frac{(x-\mu_t)^2}{2\sigma_i^2}}(\frac{1}{2\sigma_i^2}[(x-\mu_t)^2])dx=\frac{1}{2\sigma_i^2}E([x-E(x)]^2)=\frac{1}{2\sigma_i^2}D(x)=\frac{1}{2} ∫x2πσi1e2σi2(x−μt)2(2σi21[(x−μt)2])dx=2σi21E([x−E(x)]2)=2σi21D(x)=21
∫ x 1 2 π σ i e ( x − μ t ) 2 2 σ i 2 ( − 1 2 σ i 2 [ ( x − μ θ ) 2 ] ) d x = σ i 2 + ∣ ∣ μ θ − μ t ∣ ∣ 2 2 σ i 2 \int_x\frac{1}{\sqrt{2\pi}\sigma_i}e^{\frac{(x-\mu_t)^2}{2\sigma_i^2}}(-\frac{1}{2\sigma_i^2}[(x-\mu_\theta)^2])dx=\frac{\sigma_i^2+||\mu_\theta-\mu_t||^2}{2\sigma_i^2} ∫x2πσi1e2σi2(x−μt)2(−2σi21[(x−μθ)2])dx=2σi2σi2+∣∣μθ−μt∣∣2
因此: K L [ N ( μ t , σ i 2 I ) ∣ ∣ N ( μ θ , σ i 2 I ) ] = σ i 2 + ∣ ∣ μ θ − μ t ∣ ∣ 2 2 σ i 2 − 1 2 KL[N(\mu_t,\sigma_i^2I)||N(\mu_\theta,\sigma_i^2I)]=\frac{\sigma_i^2+||\mu_\theta-\mu_t||^2}{2\sigma_i^2}-\frac{1}{2} KL[N(μt,σi2I)∣∣N(μθ,σi2I)]=2σi2σi2+∣∣μθ−μt∣∣2−21
我们的 E q ∗ [ L t − 1 ] E_{q*}[L_{t-1}] Eq∗[Lt−1]优化目标即为如下所示
E q ∗ [ L t − 1 ] = ( E q ∗ [ ∣ ∣ μ t ( x t , x 0 ) − μ θ ( x t , t ) ∣ ∣ 2 2 σ i 2 ] ) + C E_{q*}[L_{t-1}]= (E_{q*}[\frac{||\mu_t(x_t,x_0)-\mu_\theta(x_t,t)||^2}{2\sigma_i^2}])+C Eq∗[Lt−1]=(Eq∗[2σi2∣∣μt(xt,x0)−μθ(xt,t)∣∣2])+C
而我们知道,事实上 x t x_t xt是通过采样生成的,也即 x t = ∏ i = 0 t a i x 0 + ( 1 − ∏ i = 0 t a i ) z x_t=\prod \limits_{i=0}^t \sqrt{a_i}x_0+\sqrt{(1-\prod \limits_{i=0}^t {a_i})}z xt=i=0∏taix0+(1−i=0∏tai)z,那么会有其中 x 0 = ( x t − ( 1 − ∏ i = 0 t a i ) z ) / ∏ i = 0 t a i x_0=(x_t-\sqrt{(1-\prod \limits_{i=0}^t {a_i})}z)/\prod \limits_{i=0}^t \sqrt{a_i} x0=(xt−(1−i=0∏tai)z)/i=0∏tai.
将它带入 μ t \mu_t μt中进行整理,那么事实上会得到:
μ t ( x t , x 0 ) = 1 a t ( x t − 1 − a t ( 1 − ∏ i = 0 k a i ) z ) \mu_t(x_t,x_0)=\frac{1}{\sqrt{a_t}}(x_t-\frac{1-a_t}{\sqrt{(1-\prod \limits_{i=0}^k {a_i})}}z) μt(xt,x0)=at1(xt−(1−i=0∏kai)1−atz)
E q ∗ [ L t − 1 ] − C = E x 0 , z [ 1 2 σ i 2 ∣ ∣ 1 a t ( x t − 1 − a t ( 1 − ∏ i = 0 k a i ) z ) − μ θ ( x t , t ) ∣ ∣ 2 ] E_{q*}[L_{t-1}]-C=E_{x_0,z}[\frac{1}{2 \sigma_i^2} ||\frac{1}{\sqrt{a_t}}(x_t-\frac{1-a_t}{\sqrt{(1-\prod \limits_{i=0}^k {a_i})}}z)-\mu_\theta(x_t,t)||^2] Eq∗[Lt−1]−C=Ex0,z[2σi21∣∣at1(xt−(1−i=0∏kai)1−atz)−μθ(xt,t)∣∣2]
注意到,此时输入给网络的 x t x_t xt是作为输入进去的,随机部分出在 z z z身上,如果我们令:
μ θ ( x t , t ) = 1 a t ( x t − 1 − a t ( 1 − ∏ i = 0 k a i ) z θ ( x t , t ) ) \mu_\theta(x_t,t)=\frac{1}{\sqrt{a_t}}(x_t-\frac{1-a_t}{\sqrt{(1-\prod \limits_{i=0}^k {a_i})}}z_\theta(x_t,t)) μθ(xt,t)=at1(xt−(1−i=0∏kai)1−atzθ(xt,t))
那么首先: E q ∗ [ L t − 1 ] − C = E_{q*}[L_{t-1}]-C= Eq∗[Lt−1]−C=
E x 0 , z [ 1 2 σ i 2 ∣ ∣ 1 a t ( x t − 1 − a t ( 1 − ∏ i = 0 k a i ) z ) − 1 a t ( x t − 1 − a t ( 1 − ∏ i = 0 k a i ) z θ ) ∣ ∣ 2 ] E_{x_0,z}[\frac{1}{2 \sigma_i^2}||\frac{1}{\sqrt{a_t}}(x_t-\frac{1-a_t}{\sqrt{(1-\prod \limits_{i=0}^k {a_i})}}z)-\frac{1}{\sqrt{a_t}}(x_t-\frac{1-a_t}{\sqrt{(1-\prod \limits_{i=0}^k {a_i})}}z_\theta)||^2] Ex0,z[2σi21∣∣at