This article first gives a high-level overview of the diffusion model with the details stripped away; the detailed derivations follow below in the sections on the forward diffusion process, the reverse diffusion process, and the loss function. If you only want a qualitative understanding of diffusion models and do not care about the derivations, the overview alone is enough.
As shown in the figure above, the forward diffusion process runs from right to left ($X_0 \rightarrow X_T$): noise is gradually added to the image. $X_{t+1}$ is obtained by adding noise to $X_t$ and depends only on $X_t$, so the forward process is a Markov chain. $X_0$ is an image sampled from the real dataset; after adding noise to $X_0$ $T$ times the image becomes progressively blurrier, and when $T$ is large enough, $X_T$ is approximately standard normal. During training the noise added at each step is known, i.e. $q(X_t|X_{t-1})$ is known, and by the Markov property we can unroll the recursion to obtain $q(X_t|X_0)$, so $q(X_t|X_0)$ is known as well. The core of the forward process is the derivation of $q(X_t|X_0)$ and $q(X_t|X_{t-1})$; see the forward diffusion section below for details.
As shown in the figure above, the reverse diffusion process runs from left to right ($X_T \rightarrow X_0$): an image is gradually recovered from noise. If, given $X_t$, we knew the distribution of $X_{t-1}$, i.e. if we knew $q(X_{t-1}|X_t)$, then starting from any noise image we could generate an image by sampling step by step. Clearly $q(X_{t-1}|X_t)$ is intractable, which is why we approximate it with $p_{\Theta}(X_{t-1}|X_t)$; $p_{\Theta}(X_{t-1}|X_t)$ is the network we train, a U-Net in the original paper. The elegant part is that although $q(X_{t-1}|X_t)$ is unknown, $q(X_{t-1}|X_t,X_0)$ *can* be expressed in terms of $q(X_t|X_0)$ and $q(X_t|X_{t-1})$, i.e. $q(X_{t-1}|X_t,X_0)$ is tractable, so we can use $q(X_{t-1}|X_t,X_0)$ to supervise the training of $p_{\Theta}(X_{t-1}|X_t)$. The core of the reverse process is the derivation of $q(X_{t-1}|X_t,X_0)$; see the reverse diffusion section below for details.
We have established that we want to train $p_{\Theta}(X_{t-1}|X_t)$, but how do we choose the objective? Two natural candidates are the negative log-likelihood, $-\log p_{\Theta}(X_0)$, and the cross-entropy between the true and predicted distributions, $-E_{q(X_0)}\log p_{\Theta}(X_0)$. Neither is tractable directly, so, following VAEs, instead of optimizing them we optimize a variational upper bound on them (equivalently, the negative of the evidence lower bound), defining $L_{VLB}$ as follows:
$$
L_{VLB} = E_{q(X_{0:T})}\left[\log\frac{q(X_{1:T}|X_0)}{p_{\Theta}(X_{0:T})}\right] \tag{1}
$$

It can be shown that $L_{VLB} \ge -\log p_{\Theta}(X_0)$ and $L_{VLB} \ge -E_{q(X_0)}\log p_{\Theta}(X_0)$.
That is, decreasing $L_{VLB}$ decreases an upper bound on both $-\log p_{\Theta}(X_0)$ and $-E_{q(X_0)}\log p_{\Theta}(X_0)$.
Moreover, $L_{VLB}$ can be rewritten as:
$$
L_{VLB} = L_T + L_{T-1} + \dots + L_0 \tag{2}
$$

$$
\begin{aligned}
\text{where:}\quad L_T &= D_{KL}(q(X_T|X_0)\,\|\,p_{\Theta}(X_T))\\
L_t &= D_{KL}(q(X_t|X_{t+1},X_0)\,\|\,p_{\Theta}(X_t|X_{t+1})),\quad 1 \le t \le T-1\\
L_0 &= -\log p_{\Theta}(X_0|X_1)
\end{aligned}
$$
From the expression above, $L_t$ is exactly the KL divergence between $q(X_t|X_{t+1},X_0)$ and $p_{\Theta}(X_t|X_{t+1})$ from the reverse diffusion process, which is what I meant earlier by using $q(X_t|X_{t+1},X_0)$ to supervise the training of $p_{\Theta}(X_t|X_{t+1})$. This part is mainly the derivation of equations (1) and (2); see the loss function section below for details.
As shown in the figure above, the forward diffusion process runs from right to left ($X_0 \rightarrow X_T$): noise is gradually added to the image, and $X_{t+1}$ is obtained by adding noise to $X_t$, depending only on $X_t$, so the forward process is a Markov chain. The noise scale at each step is controlled by the schedule $\{\beta_t \in (0,1)\}_{t=1}^{T}$. $q(X_t|X_{t-1})$ can be written as follows: given $X_{t-1}$, $X_t$ follows a normal distribution with mean $\sqrt{1-\beta_t}X_{t-1}$ and variance $\beta_t I$:
$$
q(X_t|X_{t-1}) = N(X_t;\, \sqrt{1-\beta_t}X_{t-1},\, \beta_t I) \tag{3}
$$
Using the reparameterization trick to express $X_t$, let $\alpha_t = 1-\beta_t$ and $Z_t \sim N(0,I)$ for $t \ge 0$; then:
$$
X_t = \sqrt{\alpha_t}X_{t-1} + \sqrt{1-\alpha_t}Z_{t-1} \tag{4}
$$
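The single noising step of equation (4) can be sketched numerically. This is a minimal sketch assuming NumPy, with a flattened image treated as a vector; `q_sample_step` is a hypothetical helper name, not from the paper:

```python
import numpy as np

def q_sample_step(x_prev, beta_t, rng):
    """One forward step of eq. (4): x_t = sqrt(alpha_t) x_{t-1} + sqrt(1 - alpha_t) z."""
    alpha_t = 1.0 - beta_t
    z = rng.standard_normal(x_prev.shape)  # z ~ N(0, I)
    return np.sqrt(alpha_t) * x_prev + np.sqrt(1.0 - alpha_t) * z

# starting from x_{t-1} = 0, x_t should have mean 0 and variance beta_t
rng = np.random.default_rng(0)
x_t = q_sample_step(np.zeros(100_000), 0.2, rng)
print(x_t.mean(), x_t.var())  # ≈ 0 and ≈ 0.2
```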
Writing out a few more steps:
$$
\begin{aligned}
X_{t-1} &= \sqrt{\alpha_{t-1}}X_{t-2} + \sqrt{1-\alpha_{t-1}}Z_{t-2}\\
X_{t-2} &= \sqrt{\alpha_{t-2}}X_{t-3} + \sqrt{1-\alpha_{t-2}}Z_{t-3}\\
&\;\;\vdots\\
X_1 &= \sqrt{\alpha_1}X_0 + \sqrt{1-\alpha_1}Z_0
\end{aligned}
$$
Substituting repeatedly and letting $\bar{\alpha}_t = \prod_{i=1}^{t}\alpha_i$, induction gives:
$$
X_t = \sqrt{\bar{\alpha}_t}X_0 + \frac{\sqrt{\bar{\alpha}_t}}{\sqrt{\alpha_1}}\sqrt{1-\alpha_1}Z_0 + \frac{\sqrt{\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_2}}\sqrt{1-\alpha_2}Z_1 + \frac{\sqrt{\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_3}}\sqrt{1-\alpha_3}Z_2 + \dots + \sqrt{1-\alpha_t}Z_{t-1}
$$
Define the random variable $\bar{Z}_{t-1}$ as:
$$
\bar{Z}_{t-1} = \frac{\sqrt{\bar{\alpha}_t}}{\sqrt{\alpha_1}}\sqrt{1-\alpha_1}Z_0 + \frac{\sqrt{\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_2}}\sqrt{1-\alpha_2}Z_1 + \frac{\sqrt{\bar{\alpha}_t}}{\sqrt{\bar{\alpha}_3}}\sqrt{1-\alpha_3}Z_2 + \dots + \sqrt{1-\alpha_t}Z_{t-1}
$$
Then, since the $Z_i$ are independent standard normals, the mean and variance of $\bar{Z}_{t-1}$ are:
$$
\begin{aligned}
E(\bar{Z}_{t-1}) &= 0\\
D(\bar{Z}_{t-1}) &= \frac{\bar{\alpha}_t}{\alpha_1}(1-\alpha_1) + \frac{\bar{\alpha}_t}{\bar{\alpha}_2}(1-\alpha_2) + \frac{\bar{\alpha}_t}{\bar{\alpha}_3}(1-\alpha_3) + \dots + \frac{\bar{\alpha}_t}{\bar{\alpha}_t}(1-\alpha_t) = 1-\bar{\alpha}_t
\end{aligned}
$$

The last equality holds because each term $\frac{\bar{\alpha}_t}{\bar{\alpha}_i}(1-\alpha_i) = \frac{\bar{\alpha}_t}{\bar{\alpha}_i} - \frac{\bar{\alpha}_t}{\bar{\alpha}_{i-1}}$, so the sum telescopes; being a sum of independent Gaussians, $\bar{Z}_{t-1}$ is itself Gaussian.
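The variance identity can be checked numerically. A minimal sketch, assuming a linear $\beta$ schedule with the values used in the DDPM paper:

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)   # assumed linear schedule (DDPM paper values)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)      # alpha_bars[i] = ᾱ_{i+1} (arrays are 0-indexed)

t = 500
# D(Z̄_{t-1}) = Σ_{i=1}^{t} (ᾱ_t / ᾱ_i)(1 − α_i); should equal 1 − ᾱ_t
var_sum = sum(alpha_bars[t - 1] / alpha_bars[i] * (1.0 - alphas[i]) for i in range(t))
print(var_sum, 1.0 - alpha_bars[t - 1])  # the two values agree
```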
Therefore,
$$
X_t = \sqrt{\bar{\alpha}_t}X_0 + \bar{Z}_{t-1} = \sqrt{\bar{\alpha}_t}X_0 + \sqrt{1-\bar{\alpha}_t}Z, \quad Z \sim N(0,I)
$$

$$
q(X_t|X_0) = N(X_t;\, \sqrt{\bar{\alpha}_t}X_0,\, (1-\bar{\alpha}_t)I)
$$
This completes the derivation of $q(X_t|X_{t-1})$ and $q(X_t|X_0)$.
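A sanity check that noising step by step via equation (4) and sampling $q(X_t|X_0)$ in one shot agree in distribution. A minimal sketch using scalar "pixels" and an assumed linear $\beta$ schedule:

```python
import numpy as np

rng = np.random.default_rng(0)
T = 200
betas = np.linspace(1e-4, 0.02, T)        # assumed linear schedule
alpha_bars = np.cumprod(1.0 - betas)

x0 = 2.0 + 0.5 * rng.standard_normal(100_000)   # toy data with nonzero mean

# noise step by step with eq. (4)
x = x0.copy()
for beta in betas:
    x = np.sqrt(1.0 - beta) * x + np.sqrt(beta) * rng.standard_normal(x.shape)

# sample q(X_T | X_0) directly in one shot
x_direct = np.sqrt(alpha_bars[-1]) * x0 + np.sqrt(1.0 - alpha_bars[-1]) * rng.standard_normal(x0.shape)

print(x.mean(), x_direct.mean())  # both ≈ sqrt(ᾱ_T) · 2
print(x.std(), x_direct.std())    # matching spreads
```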
If, given $X_t$, we knew the distribution of $X_{t-1}$, i.e. if we knew $q(X_{t-1}|X_t)$, then starting from any noise image we could generate an image by sampling step by step. Clearly $q(X_{t-1}|X_t)$ is intractable, which is why we approximate it with $p_{\Theta}(X_{t-1}|X_t)$; $p_{\Theta}(X_{t-1}|X_t)$ is the network we train. The elegant part is that although $q(X_{t-1}|X_t)$ is unknown, $q(X_{t-1}|X_t,X_0)$ *can* be expressed in terms of $q(X_t|X_0)$ and $q(X_t|X_{t-1})$, i.e. $q(X_{t-1}|X_t,X_0)$ is tractable.
We now derive $q(X_{t-1}|X_t,X_0)$:
$$
q(X_{t-1}|X_t,X_0) = \frac{q(X_0,X_{t-1},X_t)}{q(X_0,X_t)} = \frac{q(X_0,X_{t-1},X_t)}{q(X_0,X_{t-1})}\cdot\frac{q(X_0,X_{t-1})}{q(X_0,X_t)} = q(X_t|X_{t-1},X_0)\cdot\frac{q(X_{t-1}|X_0)}{q(X_t|X_0)}
$$

Since the forward process is a Markov chain, $q(X_t|X_{t-1},X_0) = q(X_t|X_{t-1})$, hence

$$
q(X_{t-1}|X_t,X_0) = q(X_t|X_{t-1})\cdot\frac{q(X_{t-1}|X_0)}{q(X_t|X_0)}
$$
We have now expressed $q(X_{t-1}|X_t,X_0)$ in terms of $q(X_t|X_0)$ and $q(X_t|X_{t-1})$; next we derive its explicit form:
$$
\begin{aligned}
q(X_t|X_{t-1}) &= N(X_t;\sqrt{1-\beta_t}X_{t-1},\beta_t I) = \frac{1}{\sqrt{2\pi(1-\alpha_t)}}\exp\left(-\frac{1}{2}\frac{(X_t-\sqrt{\alpha_t}X_{t-1})^2}{1-\alpha_t}\right)\\
q(X_t|X_0) &= N(X_t;\sqrt{\bar{\alpha}_t}X_0,(1-\bar{\alpha}_t)I) = \frac{1}{\sqrt{2\pi(1-\bar{\alpha}_t)}}\exp\left(-\frac{1}{2}\frac{(X_t-\sqrt{\bar{\alpha}_t}X_0)^2}{1-\bar{\alpha}_t}\right)\\
q(X_{t-1}|X_0) &= N(X_{t-1};\sqrt{\bar{\alpha}_{t-1}}X_0,(1-\bar{\alpha}_{t-1})I) = \frac{1}{\sqrt{2\pi(1-\bar{\alpha}_{t-1})}}\exp\left(-\frac{1}{2}\frac{(X_{t-1}-\sqrt{\bar{\alpha}_{t-1}}X_0)^2}{1-\bar{\alpha}_{t-1}}\right)
\end{aligned}
$$
Multiplying the densities and completing the square in $X_{t-1}$ (with $C(X_0,X_t)$ collecting terms independent of $X_{t-1}$):

$$
q(X_{t-1}|X_t,X_0) = \frac{1}{\sqrt{2\pi\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t}}\exp\left(-\frac{1}{2\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t}\left(X_{t-1}^2 - 2\left(\frac{(1-\bar{\alpha}_{t-1})\sqrt{\alpha_t}X_t}{1-\bar{\alpha}_t} + \frac{\beta_t\sqrt{\bar{\alpha}_{t-1}}X_0}{1-\bar{\alpha}_t}\right)X_{t-1} + C(X_0,X_t)\right)\right)
$$

$$
q(X_{t-1}|X_t,X_0) = N\left(X_{t-1};\, \frac{(1-\bar{\alpha}_{t-1})\sqrt{\alpha_t}X_t}{1-\bar{\alpha}_t} + \frac{\beta_t\sqrt{\bar{\alpha}_{t-1}}X_0}{1-\bar{\alpha}_t},\; \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t\right)
$$

Since $X_t = \sqrt{\bar{\alpha}_t}X_0 + \sqrt{1-\bar{\alpha}_t}Z$ with $Z \sim N(0,I)$,

$$
q(X_{t-1}|X_t,X_0) = N\left(X_{t-1};\, \frac{1}{\sqrt{\alpha_t}}X_t - \frac{\beta_t}{\sqrt{\alpha_t(1-\bar{\alpha}_t)}}Z,\; \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t\right), \quad Z \sim N(0,I)
$$
This gives the explicit distribution of $q(X_{t-1}|X_t,X_0)$.
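The two equivalent forms of the posterior mean above (one in terms of $X_0$, one in terms of the noise $Z$) can be cross-checked numerically. A minimal sketch, with `posterior_params` a hypothetical helper and an assumed linear $\beta$ schedule:

```python
import numpy as np

def posterior_params(x0, xt, t, betas, alpha_bars):
    """Mean and variance of q(X_{t-1} | X_t, X_0); t is 1-indexed, arrays 0-indexed."""
    beta_t = betas[t - 1]
    alpha_t = 1.0 - beta_t
    ab_t = alpha_bars[t - 1]
    ab_prev = alpha_bars[t - 2] if t > 1 else 1.0
    mean = ((1 - ab_prev) * np.sqrt(alpha_t) * xt + beta_t * np.sqrt(ab_prev) * x0) / (1 - ab_t)
    var = (1 - ab_prev) / (1 - ab_t) * beta_t
    return mean, var

T = 100
betas = np.linspace(1e-4, 0.02, T)
alpha_bars = np.cumprod(1.0 - betas)

# build x_t from x_0 and a known noise z, then compare both mean expressions
rng = np.random.default_rng(0)
x0, z, t = 0.7, rng.standard_normal(), 40
xt = np.sqrt(alpha_bars[t - 1]) * x0 + np.sqrt(1 - alpha_bars[t - 1]) * z
mean_x0, var = posterior_params(x0, xt, t, betas, alpha_bars)
beta_t = betas[t - 1]
mean_z = xt / np.sqrt(1 - beta_t) - beta_t / np.sqrt((1 - beta_t) * (1 - alpha_bars[t - 1])) * z
print(mean_x0, mean_z)  # the two forms agree
```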
The loss function section below explains how $q(X_{t-1}|X_t,X_0)$ is used to supervise the training of $p_{\Theta}(X_{t-1}|X_t)$.
We have established that we want to train $p_{\Theta}(X_{t-1}|X_t)$, but how do we choose the objective? Two natural candidates are the negative log-likelihood, $-\log p_{\Theta}(X_0)$, and the cross-entropy between the true and predicted distributions, $-E_{q(X_0)}\log p_{\Theta}(X_0)$. However, as with VAEs, integrating over the noise space is intractable, so directly optimizing $-\log p_{\Theta}(X_0)$ or $-E_{q(X_0)}\log p_{\Theta}(X_0)$ is very hard. Instead of optimizing them directly, we optimize their variational upper bound $L_{VLB}$, defined as:
$$
L_{VLB} = E_{q(X_{0:T})}\left[\log\frac{q(X_{1:T}|X_0)}{p_{\Theta}(X_{0:T})}\right]
$$
We now show that $L_{VLB}$ is an upper bound on both $-\log p_{\Theta}(X_0)$ and $-E_{q(X_0)}\log p_{\Theta}(X_0)$, i.e. that $L_{VLB} \ge -\log p_{\Theta}(X_0)$ and $L_{VLB} \ge -E_{q(X_0)}\log p_{\Theta}(X_0)$:
$$
\begin{aligned}
-\log p_{\Theta}(X_0) &\le -\log p_{\Theta}(X_0) + D_{KL}(q(X_{1:T}|X_0)\,\|\,p_{\Theta}(X_{1:T}|X_0))\\
&= -\log p_{\Theta}(X_0) + E_{X_{1:T}\sim q(X_{1:T}|X_0)}\left[\log\frac{q(X_{1:T}|X_0)}{p_{\Theta}(X_{1:T}|X_0)}\right]\\
&= -\log p_{\Theta}(X_0) + E_{X_{1:T}\sim q(X_{1:T}|X_0)}\left[\log\frac{q(X_{1:T}|X_0)\,p_{\Theta}(X_0)}{p_{\Theta}(X_{0:T})}\right]\\
&= -\log p_{\Theta}(X_0) + E_{X_{1:T}\sim q(X_{1:T}|X_0)}\left[\log\frac{q(X_{1:T}|X_0)}{p_{\Theta}(X_{0:T})} + \log p_{\Theta}(X_0)\right]\\
&= E_{X_{0:T}\sim q(X_{0:T})}\left[\log\frac{q(X_{1:T}|X_0)}{p_{\Theta}(X_{0:T})}\right] = L_{VLB}
\end{aligned}
$$
$$
\begin{aligned}
L_{CE} &= -\int q(X_0)\log p_{\Theta}(X_0)\,dX_0 = -E_{q(X_0)}\log p_{\Theta}(X_0)\\
&= -E_{q(X_0)}\log\left(\int p_{\Theta}(X_{1:T}|X_0)\,p_{\Theta}(X_0)\,dX_{1:T}\right)\\
&= -E_{q(X_0)}\log\left(\int p_{\Theta}(X_{0:T})\,dX_{1:T}\right)\\
&= -E_{q(X_0)}\log\left(\int q(X_{1:T}|X_0)\frac{p_{\Theta}(X_{0:T})}{q(X_{1:T}|X_0)}\,dX_{1:T}\right)\\
&= -E_{q(X_0)}\log\left(E_{q(X_{1:T}|X_0)}\frac{p_{\Theta}(X_{0:T})}{q(X_{1:T}|X_0)}\right)\\
&\le -E_{q(X_0)}E_{q(X_{1:T}|X_0)}\log\frac{p_{\Theta}(X_{0:T})}{q(X_{1:T}|X_0)} \quad \text{(Jensen's inequality)}\\
&= E_{q(X_{0:T})}\left[\log\frac{q(X_{1:T}|X_0)}{p_{\Theta}(X_{0:T})}\right] = L_{VLB}
\end{aligned}
$$
This proves that $L_{VLB}$ is an upper bound on both $-\log p_{\Theta}(X_0)$ and $-E_{q(X_0)}\log p_{\Theta}(X_0)$.
Next, we simplify $L_{VLB}$:
$$
\begin{aligned}
L_{VLB} &= E_{q(X_{0:T})}\left[\log\frac{q(X_{1:T}|X_0)}{p_{\Theta}(X_{0:T})}\right]\\
&= E_{q(X_{0:T})}\left[\log\frac{\prod_{t=1}^{T}q(X_t|X_{t-1})}{p_{\Theta}(X_T)\prod_{t=1}^{T}p_{\Theta}(X_{t-1}|X_t)}\right]\\
&= E_{q(X_{0:T})}\left[-\log p_{\Theta}(X_T) + \sum_{t=1}^{T}\log\frac{q(X_t|X_{t-1})}{p_{\Theta}(X_{t-1}|X_t)}\right]\\
&= E_{q(X_{0:T})}\left[-\log p_{\Theta}(X_T) + \sum_{t=2}^{T}\log\frac{q(X_t|X_{t-1})}{p_{\Theta}(X_{t-1}|X_t)} + \log\frac{q(X_1|X_0)}{p_{\Theta}(X_0|X_1)}\right]\\
&= E_{q(X_{0:T})}\left[-\log p_{\Theta}(X_T) + \sum_{t=2}^{T}\log\left(\frac{q(X_{t-1}|X_t,X_0)}{p_{\Theta}(X_{t-1}|X_t)}\cdot\frac{q(X_t|X_0)}{q(X_{t-1}|X_0)}\right) + \log\frac{q(X_1|X_0)}{p_{\Theta}(X_0|X_1)}\right]\\
&= E_{q(X_{0:T})}\left[-\log p_{\Theta}(X_T) + \sum_{t=2}^{T}\log\frac{q(X_{t-1}|X_t,X_0)}{p_{\Theta}(X_{t-1}|X_t)} + \sum_{t=2}^{T}\log\frac{q(X_t|X_0)}{q(X_{t-1}|X_0)} + \log\frac{q(X_1|X_0)}{p_{\Theta}(X_0|X_1)}\right]\\
&= E_{q(X_{0:T})}\left[-\log p_{\Theta}(X_T) + \sum_{t=2}^{T}\log\frac{q(X_{t-1}|X_t,X_0)}{p_{\Theta}(X_{t-1}|X_t)} + \log\frac{q(X_T|X_0)}{q(X_1|X_0)} + \log\frac{q(X_1|X_0)}{p_{\Theta}(X_0|X_1)}\right]\\
&= E_{q(X_{0:T})}\left[\log\frac{q(X_T|X_0)}{p_{\Theta}(X_T)} + \sum_{t=2}^{T}\log\frac{q(X_{t-1}|X_t,X_0)}{p_{\Theta}(X_{t-1}|X_t)} - \log p_{\Theta}(X_0|X_1)\right]\\
&= D_{KL}(q(X_T|X_0)\,\|\,p_{\Theta}(X_T)) + \sum_{t=2}^{T}D_{KL}(q(X_{t-1}|X_t,X_0)\,\|\,p_{\Theta}(X_{t-1}|X_t)) - \log p_{\Theta}(X_0|X_1)\\
&= L_T + L_{T-1} + \dots + L_0
\end{aligned}
$$

$$
\begin{aligned}
\text{where:}\quad L_T &= D_{KL}(q(X_T|X_0)\,\|\,p_{\Theta}(X_T))\\
L_t &= D_{KL}(q(X_t|X_{t+1},X_0)\,\|\,p_{\Theta}(X_t|X_{t+1})),\quad 1 \le t \le T-1\\
L_0 &= -\log p_{\Theta}(X_0|X_1)
\end{aligned}
$$
From $L_t$ we can see that supervising $p_{\Theta}(X_t|X_{t+1})$ amounts to minimizing the KL divergence between $p_{\Theta}(X_t|X_{t+1})$ and $q(X_t|X_{t+1},X_0)$.
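Since both $q(X_t|X_{t+1},X_0)$ and $p_{\Theta}(X_t|X_{t+1})$ are Gaussian, each $L_t$ has a closed form. A minimal sketch of the univariate Gaussian KL (applied per dimension, this covers the diagonal case assumed here for illustration), showing that with the variances held equal, minimizing the KL reduces to matching the means:

```python
import numpy as np

def gaussian_kl(mu_q, var_q, mu_p, var_p):
    """KL( N(mu_q, var_q) || N(mu_p, var_p) ) for univariate Gaussians."""
    return 0.5 * np.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / (2.0 * var_p) - 0.5

# identical Gaussians have zero KL
print(gaussian_kl(0.0, 1.0, 0.0, 1.0))   # 0.0
# with equal (fixed) variances the KL is just (Δμ)² / (2σ²)
print(gaussian_kl(0.0, 1.0, 1.0, 1.0))   # 0.5
```

This is why, in DDPM, fixing the model variance and parameterizing the mean through a predicted noise term turns $L_t$ (up to weighting) into a squared error between the true and predicted noise.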
In short, our goal is to learn a $p_{\Theta}(X_{t-1}|X_t)$ that can recover the original image from a noisy one.
To this end, we supervise the training of $p_{\Theta}(X_{t-1}|X_t)$ with $q(X_{t-1}|X_t,X_0)$, which can be expressed in terms of $q(X_t|X_0)$ and $q(X_t|X_{t-1})$, i.e. $q(X_{t-1}|X_t,X_0)$ is known.
This article reflects my personal understanding from studying the topic; if anything is wrong, I would appreciate corrections. I hope it can serve as a starting point, and I welcome discussion in the comments.