diffusion model原理和算法伪代码

文章目录

    • Diffusion model
      • 前置数学知识
      • VAE和多层VAE回顾
        • 1. 单层VAE的原理
        • 2. 多层VAE的原理
      • Diffusion model
      • 扩散过程(Diffusion Process)
      • 逆扩散过程(Reverse Process)
      • 后验的扩散条件概率 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0)存在闭式解
      • 目标-数据分布的似然函数
      • Diffusion Probabilistic Model的算法代码
        • Training
        • Sampling

Diffusion model

奠基性的工作:

  1. Ho,(2020),Denoising diffusion peobabilistic models
  2. Sohi,(2015), Deep unsupervised learning using nonequilibruim thermodynamics

前置数学知识

  1. 条件概率的一般形式
    P ( B , C ∣ A ) = P ( B ∣ A ) P ( C ∣ A , B ) P(B,C|A)=P(B|A)P(C|A,B) P(B,CA)=P(BA)P(CA,B)

  2. 基于马尔可夫假设的条件概率

    假设马尔可夫链关系 A → B → C A\to B\to C ABC,有
    P ( A , B , C ) = P ( C ∣ B ) P ( B ∣ A ) P ( A ) P(A,B,C)=P(C|B)P(B|A)P(A) P(A,B,C)=P(CB)P(BA)P(A)

  3. 高斯分布的KL散度

    对于两个单一变量的高斯分布p和q而言,他们的KL散度满足
    K L ( N ( μ 1 , σ 1 2 ) , N ( μ 2 , σ 2 2 ) ) = log ⁡ σ 2 σ 1 − 1 2 + σ 1 2 + ( μ 1 − μ 2 ) 2 2 σ 2 2 KL(\mathcal{N}(\mu_1,\sigma_1^2),\mathcal{N}(\mu_2,\sigma_2^2))=\log\frac{\sigma_2}{\sigma_1}-\frac{1}{2}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2} KL(N(μ1,σ12),N(μ2,σ22))=logσ1σ221+2σ22σ12+(μ1μ2)2
    推导详见CSDN博客

  4. 参数重整化

    若希望从高斯分布 N ( μ , σ 2 ) \mathcal{N}(\mu,\sigma^2) N(μ,σ2)中采样,可以先从标准分布 N ( 0 , 1 ) \mathcal{N}(0,1) N(0,1)得到 z z z,得到 σ ⋅ z + μ \sigma\cdot z+\mu σz+μ

    这样就可以将 σ \sigma σ μ \mu μ也作为仿射网络的一部分,而不是不可导的环境参数。

    这个技巧在VAE和Diffusion model中大量被使用。

VAE和多层VAE回顾

1. 单层VAE的原理

x → z , q ϕ ( z ∣ x ) z → x , p θ ( x ∣ z ) x\to z,\quad q_{\phi}(z|x)\\ z\to x,\quad p_{\theta}(x|z) xz,qϕ(zx)zx,pθ(xz)

此时 x x x的边缘概率分布可以改写为关于z的期望式
p ( x ) = ∫ z p θ ( x ∣ z ) p ( z ) d z = ∫ z q ϕ ( z ∣ x ) p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) d z = E z ∼ q ϕ ( z ∣ x ) p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) \begin{aligned} p(x)&=\int_zp_\theta(x|z)p(z)\text{d}z\\ &=\int_zq_\phi(z|x)\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)}\text{d}z\\ &=\mathbb{E}_{z\sim q_\phi(z|x)}\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)} \end{aligned} p(x)=zpθ(xz)p(z)dz=zqϕ(zx)qϕ(zx)pθ(xz)p(z)dz=Ezqϕ(zx)qϕ(zx)pθ(xz)p(z)
此时的Evidence存在一个lower bound(ELBO)
log ⁡ p ( x ) = log ⁡ E z ∼ q ϕ ( z ∣ x ) p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) ≥ E z ∼ q ϕ ( z ∣ x ) log ⁡ [ p θ ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) ] \log p(x)=\log\mathbb{E}_{z\sim q_\phi(z|x)}\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)} \ge\mathbb{E}_{z\sim q_\phi(z|x)}\log\left[\frac{p_\theta(x|z)p(z)}{q_\phi(z|x)}\right] logp(x)=logEzqϕ(zx)qϕ(zx)pθ(xz)p(z)Ezqϕ(zx)log[qϕ(zx)pθ(xz)p(z)]
在训练中,我们需要最大化对数似然,即Evidence,可以通过最小化lower bound实现,而这个lower bound可以分为两部分:

  1. E z ∼ q ϕ ( z ∣ x ) p θ ( x ∣ z ) \mathbb{E}_{z\sim q_\phi(z|x)}p_\theta(x|z) Ezqϕ(zx)pθ(xz),可以通过神经网络实现预测
  2. − E z ∼ q ϕ ( z ∣ x ) log ⁡ q ϕ ( x ∣ z ) p ( z ) -\mathbb{E}_{z\sim q_\phi(z|x)}\log \frac{q_\phi(x|z)}{p(z)} Ezqϕ(zx)logp(z)qϕ(xz),即两个分布的KL的散度,一般可假设 z z z服从高斯分布,而 q ϕ ( x ∣ z ) q_\phi(x|z) qϕ(xz)也逼近高斯分布,而两个高斯分布的KL散度存在公式

所以,单层VAE的损失函数是可优化的。

2. 多层VAE的原理

diffusion model原理和算法伪代码_第1张图片

基于同样的原理,
p ( x ) = ∫ z 1 ∫ z 2 p θ ( x , z 1 , z 2 ) d z 1 d z 2 = ∫ z 1 ∫ z 2 q ϕ ( z 1 , z 2 ∣ x ) p θ ( x , z 1 , z 2 ) q ϕ ( z 1 , z 2 ∣ x ) d z 1 d z 2 = E z 1 , z 2 ∼ q ϕ ( z 1 , z 2 ∣ x ) p θ ( x , z 1 , z 2 ) q ϕ ( z 1 , z 2 ∣ x ) \begin{aligned} p(x)&=\int_{z_1}\int_{z_2}p_\theta(x,z_1,z_2)\text{d}z_1\text{d}z_2\\ &=\int_{z_1}\int_{z_2}q_\phi(z_1,z_2|x)\frac{p_\theta(x,z_1,z_2)}{q_\phi(z_1,z_2|x)}\text{d}z_1\text{d}z_2\\ &=\mathbb{E}_{z1,z_2\sim q_\phi(z_1,z_2|x)}\frac{p_\theta(x,z_1,z_2)}{q_\phi(z_1,z_2|x)} \end{aligned} p(x)=z1z2pθ(x,z1,z2)dz1dz2=z1z2qϕ(z1,z2x)qϕ(z1,z2x)pθ(x,z1,z2)dz1dz2=Ez1,z2qϕ(z1,z2x)qϕ(z1,z2x)pθ(x,z1,z2)
得到
log ⁡ p ( x ) ≥ E z 1 , z 2 ∼ q ϕ ( z 1 , z 2 ∣ x ) log ⁡ p θ ( x , z 1 , z 2 ) q ϕ ( z 1 , z 2 ∣ x ) \log p(x)\ge \mathbb{E}_{z1,z_2\sim q_\phi(z_1,z_2|x)}\log \frac{p_\theta(x,z_1,z_2)}{q_\phi(z_1,z_2|x)} logp(x)Ez1,z2qϕ(z1,z2x)logqϕ(z1,z2x)pθ(x,z1,z2)
如果上述过程满足马尔科夫假设,即
p θ ( x , z 1 , z 2 ) = p ( x ∣ z 1 ) p ( z 1 ∣ z 2 ) p ( z 2 ) q ( z 1 , z 2 ∣ x ) = q ( z 1 ∣ x ) q ( z 2 ∣ z 1 ) p_\theta(x,z_1,z_2)=p(x|z_1)p(z_1|z_2)p(z_2)\\ q(z_1,z_2|x)=q(z_1|x)q(z_2|z_1) pθ(x,z1,z2)=p(xz1)p(z1z2)p(z2)q(z1,z2x)=q(z1x)q(z2z1)
(6)式能够被进一步简化为
L ( θ , ϕ ) = E q ( z 1 , z 2 ∣ x ) [ log ⁡ p ( x ∣ z 1 ) − log ⁡ q ( z 1 ∣ x ) + log ⁡ p ( z 1 ∣ z 2 ) − log ⁡ q ( z 2 ∣ z 1 ) + log ⁡ p ( z 2 ) ] \mathcal{L}(\theta,\phi)=\mathbb{E}_{q(z_1,z_2|x)} \left[ \log p(x|z_1)-\log q(z_1|x)+\log p(z_1|z_2) -\log q(z_2|z_1) +\log p(z_2) \right] L(θ,ϕ)=Eq(z1,z2x)[logp(xz1)logq(z1x)+logp(z1z2)logq(z2z1)+logp(z2)]

Diffusion model

diffusion model原理和算法伪代码_第2张图片

从右往左,从目标分布到噪声分布称为扩散过程,而我们希望学习到从左往右的逆扩散过程。上图中的第一行从左往右是扩散过程,第二行从右往左是逆扩散过程,而第三行是前两者的差值,称为偏移量。

扩散过程(Diffusion Process)

  1. 给定初始数据分布 x 0 ∼ q ( x ) \bold{x_0}\sim q(\bold{x}) x0q(x),不断向分布中添加高斯噪声,噪声的标准差是以 β t \beta_t βt确定的,均值是以固定值 β t \beta_t βt和当前时刻的数据 x t \bold{x_t} xt决定的,所以该过程并没有需要学习的参数,而且是一个马尔科夫链过程。

  2. 随着 t t t的不断增大,最终数据分布 x T x_T xT变成了一个各项独立的高斯分布
    q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(\bold{x_t|x_{t-1}})=\mathcal{N}(\bold{x_t};\sqrt{1-\beta_t}\bold{x_{t-1},\beta_t\bold{I}}) q(xtxt1)=N(xt;1βt xt1,βtI)

    q ( x 1 : T ∣ x o ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q(\bold{x_{1:T}|x_o})=\prod^{T}_{t=1}q(\bold{x_t|x_{t-1}}) q(x1:Txo)=t=1Tq(xtxt1)

    这充分体现了参数重整化的技巧。

  3. 任意时刻的 q ( x t ) q(\bold{x_t}) q(xt)推导也可以完全基于 x 0 \bold{x}_0 x0 β t \beta_t βt计算得到闭式解,而不需要做迭代。(令 α t = 1 − β t \alpha_t=1-\beta_t αt=1βt

    两个正态分布 X ∼ N ( μ 1 , σ 1 ) X\sim \mathcal{N}(\mu_1,\sigma_1) XN(μ1,σ1) Y ∼ N ( μ 2 , σ 2 ) Y\sim \mathcal{N}(\mu_2,\sigma_2) YN(μ2,σ2)叠加后的分布 a X + b Y aX+bY aX+bY服从分布 N ( a μ 1 + b μ 2 , a 2 σ 1 2 + b 2 σ 2 2 ) \mathcal{N}(a\mu_1+b\mu_2,a^2\sigma_1^2+b^2\sigma_2^2) N(aμ1+bμ2,a2σ12+b2σ22)

    对于第 t t t步的分布 x t x_t xt等于上一步的分布 x t − 1 x_{t-1} xt1加上高斯噪声 z t − 1 z_{t-1} zt1,即
    x t = α t x t − 1 + 1 − α t z t − 1 ; where  z t − 1 , z t − 2 , . . . ∼ N ( 0 , I ) = α t α t − 1 x t − 2 + α t − α t α t − 1 z t − 2 + 1 − α t z t − 1 \begin{aligned} \bold{x}_t&=\sqrt{\alpha_t}\bold{x}_{t-1}+\sqrt{1-\alpha_t}\bold{z}_{t-1}\qquad ;\text{where} \ \bold{z_{t-1}},\bold{z_{t-2}},...\sim \mathcal{N}(\bold{0},\bold{I})\\ &=\sqrt{\alpha_t\alpha_{t-1}}\bold{x}_{t-2}+{\color{red} \sqrt{\alpha_t-\alpha_t\alpha_{t-1}}\bold{z}_{t-2}+\sqrt{1-\alpha_t}\bold{z_{t-1}}} \end{aligned} xt=αt xt1+1αt zt1;where zt1,zt2,...N(0,I)=αtαt1 xt2+αtαtαt1 zt2+1αt zt1
    这里借助参数重整化的技巧,将红色部分的两个高斯分布合并为新的高斯分布,整理如下所示
    x t = α t α t − 1 x t − 2 + 1 − α t α t − 1 z ˉ t − 2 \begin{aligned} \bold{x}_t&=\sqrt{\alpha_t\alpha_{t-1}}\bold{x}_{t-2}+{\color{red} \sqrt{1-\alpha_t\alpha_{t-1}}\bar{\bold{z}}_{t-2}} \end{aligned} xt=αtαt1 xt2+1αtαt1 zˉt2
    其中, z ˉ t − 2 ∼ N ( 0 , I ) \bar{\bold{z}}_{t-2}\sim \mathcal{N}(\bold{0},\bold{I}) zˉt2N(0,I)

    重复上面的步骤,最终可以得到 z t \bold{z}_t zt的闭式解
    x t = α ˉ t x 0 + 1 − α ˉ t z ; where  α ˉ t = ∏ i = 1 T α i \bold{x}_t=\sqrt{\bar{\alpha}_t}\bold{x}_0+\sqrt{1-\bar{\alpha}_{t}}\bold{z}\qquad ;\text{where}\ \bar{\alpha}_t=\prod_{i=1}^T\alpha_i xt=αˉt x0+1αˉt z;where αˉt=i=1Tαi
    此时,作者认为 x t ∼ N ( x t ; α ˉ t x 0 , 1 − α ˉ t I ) \bold{x}_t\sim \mathcal{N}(\bold{x}_t;\sqrt{\bar{\alpha}_t}\bold{x}_0,\sqrt{1-\bar{\alpha}_t}\bold{I}) xtN(xt;αˉt x0,1αˉt I),(此处应该是认为 x 0 \bold{x}_0 x0是完全已知的,方差为零),最终当上述分布趋近于 N ( 0 , I ) \mathcal{N}(\bold{0},\bold{I}) N(0,I)的时候,可认为模型已经基本完成扩散过程。因此,作者给出了一种 β t \beta_t βt的设置经验, β 1 < β 2 < ⋅ ⋅ ⋅ < β T \beta_1<\beta_2<\cdot\cdot\cdot<\beta_T β1<β2<<βT,即随着扩散深度的加深,逐步扩大 β \beta β

    逆扩散过程(Reverse Process)

    逆过程是从高斯分布中恢复原始数据,当 β t \beta_t βt足够小时,逆过程的每一小步 p θ ( x t − 1 ∣ x t ) p_\theta(\bold{x}_{t-1}|\bold{x}_t) pθ(xt1xt)也可视作高斯分布,即
    p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , ∑ θ ( x t , t ) ) p_\theta(\bold{x}_{t-1}|\bold{x}_t)=\mathcal{N}(\bold{x}_{t-1};\bold{\mu_\theta}(\bold{x}_t,t),\sum_\theta(\bold{x}_t,t)) pθ(xt1xt)=N(xt1;μθ(xt,t),θ(xt,t))
    逆扩散过程可以被总结为如下形式
    p θ ( x 0 : T ) = p ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) p_\theta(\bold{x}_{0:T})=p(\bold{x}_T)\prod_{t=1}^Tp_\theta (\bold{x}_{t-1}|\bold{x}_t) pθ(x0:T)=p(xT)t=1Tpθ(xt1xt)
    此处通过使用网络估计参数 θ \theta θ以实现逆扩散过程。

    后验的扩散条件概率 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0)存在闭式解

    根据条件概率的贝叶斯公式
    q ( x t − 1 ∣ x t , x 0 ) q ( x t ∣ x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q(\bold{x}_{t-1}|\bold{x}_t,\bold{x}_0)q(\bold{x}_t|\bold{x}_0)=q(\bold{x}_{t}|\bold{x}_{t-1},\bold{x}_0)q(\bold{x}_{t-1}|\bold{x}_0) q(xt1xt,x0)q(xtx0)=q(xtxt1,x0)q(xt1x0)
    得到
    q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) ∝ exp ⁡ ( − 1 2 ( ( x t − α t x t − 1 ) 2 1 − α t + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) ) = exp ⁡ ( − 1 2 ( a x t − 1 2 + b x t − 1 + c ( x t , x 0 ) ) ) \begin{aligned} q(\bold{x}_{t-1}|\bold{x}_t,\bold{x}_0)&=q(\bold{x}_{t}|\bold{x}_{t-1},\bold{x}_0)\frac{q(\bold{x}_{t-1}|\bold{x}_0)}{q(\bold{x}_t|\bold{x}_0)}\\ &\propto \exp \left(-\frac{1}{2}\left(\frac{(\bold{x}_t-\sqrt{\alpha_t}\bold{x}_{t-1})^2}{1-\alpha_t}+\frac{(\bold{x}_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\bold{x}_0)^2}{1-\bar{\alpha}_{t-1}}-\frac{(\bold{x}_{t}-\sqrt{\bar{\alpha}_{t}}\bold{x}_0)^2}{1-\bar{\alpha}_{t}}\right)\right)\\ &=\exp\left(-\frac{1}{2}\left({\color{blue} a}\bold{x}_{t-1}^2+{\color{red} b}\bold{x}_{t-1}+c(\bold{x}_t,\bold{x}_0)\right)\right) \end{aligned} q(xt1xt,x0)=q(xtxt1,x0)q(xtx0)q(xt1x0)exp(21(1αt(xtαt xt1)2+1αˉt1(xt1αˉt1 x0)21αˉt(xtαˉt x0)2))=exp(21(axt12+bxt1+c(xt,x0)))
    可见,上述分布的核心可以用一个二次函数来描述,那对应的中轴线应该是
    μ ~ t ( x t , x 0 ) = − b 2 a = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ x 0 \begin{aligned} \bold{\tilde\mu_t}(\bold{x}_t,\bold{x}_0)&=-\frac{\color{red}{b}}{2\color{blue}{a}}\\ &=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\bold{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}}\bold{x}_0 \end{aligned} μ~t(xt,x0)=2ab=1αˉtαt (1αˉt1)xt+1αˉαˉt1 βtx0
    容易从扩散过程的表达式(式11)得到 x 0 \bold{x}_0 x0的表达式
    x 0 = 1 α ˉ t ( x t − 1 − α ˉ t z ) \bold{x}_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\left(\bold{x}_t-\sqrt{1-\bar{\alpha}_{t}}\bold{z}\right) x0=αˉt 1(xt1αˉt z)
    带入得到
    μ ~ t ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ 1 α ˉ t ( x t − 1 − α ˉ t z ) = 1 α t ( x t − β t 1 − α ˉ t z t ) \begin{aligned} \bold{\tilde\mu_t}(\bold{x}_t,\bold{x}_0) &=\frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\bold{x}_t+\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}}\frac{1}{\sqrt{\bar{\alpha}_t}}\left(\bold{x}_t-\sqrt{1-\bar{\alpha}_{t}}\bold{z}\right)\\ &=\color{green}{\frac{1}{\sqrt{\alpha}_t}\left(\bold{x}_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\bold{z}_t\right)} \end{aligned} μ~t(xt,x0)=1αˉtαt (1αˉt1)xt+1αˉαˉt1 βtαˉt 1(xt1αˉt z)=α t1(xt1αˉt βtzt)
    这就是 x t − 1 \bold{x}_{t-1} xt1分布的均值表达式,即给定 x 0 \bold{x}_0 x0的条件下,后验条件高斯分布的均值计算只与 x t \bold{x}_{t} xt z t \bold{z}_t zt有关。

目标-数据分布的似然函数

我们在待优化的目标数据分布的似然函数(负)上加一个非负的KL散度,构成负对数似然的上界,通过最小化上界,负对数似然自然取得最小。
− log ⁡ p θ ( x 0 ) = − log ⁡ p θ ( x 0 ) + D K L ( q ( x 1 : T ∣ x 0 ) ∣ ∣ p θ ( x 1 : T ∣ x 0 ) ) = − log ⁡ p θ ( x 0 ) + E x 1 : T ∼ q ( x 1 : T ∣ x 0 ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) / p θ ( x 0 ) ] = − log ⁡ p θ ( x 0 ) + E x 1 : T ∼ q ( x 1 : T ∣ x 0 ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) + log ⁡ p θ ( x 0 ) ] = E x 1 : T ∼ q ( x 1 : T ∣ x 0 ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] = L V L B \begin{aligned} -\log p_\theta(\bold{x}_0) &=-\log p_\theta(\bold{x}_0) + D_{KL}(q(\bold{x}_{1:T}|\bold{x}_0)||p_\theta(\bold{x}_{1:T}|\bold{x}_0))\\ &=-\log p_\theta(\bold{x}_0) +\mathbb{E}_{\bold{x}_{1:T}\sim q(\bold{x}_{1:T}|\bold{x}_0)}\left[\log \frac{q(\bold{x}_{1:T}|\bold{x}_0)}{p_\theta(\bold{x}_{0:T})/p_\theta (\bold{x}_0)}\right]\\ &=-\log p_\theta(\bold{x}_0) +\mathbb{E}_{\bold{x}_{1:T}\sim q(\bold{x}_{1:T}|\bold{x}_0)}\left[\log \frac{q(\bold{x}_{1:T}|\bold{x}_0)}{p_\theta(\bold{x}_{0:T})}+\log p_\theta (\bold{x}_0)\right]\\ &=\mathbb{E}_{\bold{x}_{1:T}\sim q(\bold{x}_{1:T}|\bold{x}_0)}\left[\log \frac{q(\bold{x}_{1:T}|\bold{x}_0)}{p_\theta(\bold{x}_{0:T})}\right]\qquad \color{blue}{=L_{VLB}} \end{aligned} logpθ(x0)=logpθ(x0)+DKL(q(x1:Tx0)pθ(x1:Tx0))=logpθ(x0)+Ex1:Tq(x1:Tx0)[logpθ(x0:T)/pθ(x0)q(x1:Tx0)]=logpθ(x0)+Ex1:Tq(x1:Tx0)[logpθ(x0:T)q(x1:Tx0)+logpθ(x0)]=Ex1:Tq(x1:Tx0)[logpθ(x0:T)q(x1:Tx0)]=LVLB
我们也可以继续对 L V L B L_{VLB} LVLB进行展开,过程比较繁琐,建议查看论文,最终的形式如下
L V L B = E q [ D K L ( q ( x T ∣ x 0 ) ∣ ∣ p θ ( x T ) ) + ∑ t = 1 T D K L ( q ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p θ ( x t − 1 ∣ x t ) ) ] {\color{blue}{L_{VLB}}}=\mathbb{E}_q\left[D_{KL}\left(q(\bold{x}_T|\bold{x}_0)||p_\theta (\bold{x}_T)\right)+{\color{red} \sum_{t=1}^TD_{KL}(q(\bold{x}_{t-1}|\bold{x}_t,\bold{x}_0)||p_\theta(\bold{x}_{t-1}|\bold{x}_t))}\right] LVLB=Eq[DKL(q(xTx0)pθ(xT))+t=1TDKL(q(xt1xt,x0)pθ(xt1xt))]
其中第一项是不含待优化参数的,仅仅需要优化第二项即可。而且作者将 p θ ( x t − 1 ∣ x t ) p_\theta(\bold{x}_{t-1}|\bold{x}_t) pθ(xt1xt)的方差设置为与 β \beta β有关的常数,可训练参数仅存在其均值中。

我们已经知道 q ( x t − 1 ∣ x t , x 0 ) q(\bold{x}_{t-1}|\bold{x}_t,\bold{x}_0) q(xt1xt,x0)服从高斯分布,并给出了其均值的表达式,而且知道 p θ ( x t − 1 ∣ x t ) p_\theta(\bold{x}_{t-1}|\bold{x}_t) pθ(xt1xt)也服从高斯分布,其方差设置为常数,仅需优化均值即可。使用文章开头给出的两个单一变量的高斯分布的KL散度表达式,两个分布的方差均为常数,最终的损失函数可以写作两个分布的均值的关系:
L t − 1 = E q [ 1 2 σ t 2 ∣ ∣ μ ~ t ( x t , x 0 ) − μ θ ( x t , t ) ∣ ∣ 2 ] + C {\color{red} L_{t-1}}=\mathbb{E}_q\left[\frac{1}{2\sigma_t^2}||\tilde{\bold\mu}_t(\bold{x}_t,\bold{x}_0)-\mu_\theta(\bold{x}_t,t)||^2\right]+C Lt1=Eq[2σt21μ~t(xt,x0)μθ(xt,t)2]+C
我们可以将已经得到的 μ t \mu_t μt的表达式,进行简化得到最终的损失函数:
L simple ( θ ) : = E t , x 0 , ϵ [ ∣ ∣ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∣ ∣ 2 ] L_{\text{simple}}(\theta):=\mathbb{E}_{t,\bold{x}_0,\bold\epsilon}\left[||\bold\epsilon-\bold\epsilon_\theta(\sqrt{\bar{\alpha}_t}\bold{x}_0+\sqrt{1-\bar{\alpha}_t}\bold\epsilon,t)||^2\right] Lsimple(θ):=Et,x0,ϵ[ϵϵθ(αˉt x0+1αˉt ϵ,t)2]
这里, ϵ θ \epsilon_\theta ϵθ就是可学习的网络,输入 x 0 \bold{x}_0 x0和高斯噪声 ϵ \epsilon ϵ以及时刻 t t t

Diffusion Probabilistic Model的算法代码

Training

  1. repeat
  2. x 0 ∼ q ( x 0 ) x_0\sim q(x_0) x0q(x0)
  3. t ∼ Uniform ( { 1 , 2 , . . , T } ) t\sim \text{Uniform}(\{1,2,..,T\}) tUniform({1,2,..,T})
  4. ϵ ∼ N ( 0 , I ) \epsilon\sim \mathcal{N}(\bold{0},\bold{I}) ϵN(0,I)
  5. Take gradient descent step on
  6. ∇ θ ∣ ∣ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∣ ∣ 2 \nabla _\theta||\bold\epsilon-\bold\epsilon_\theta(\sqrt{\bar{\alpha}_t}\bold{x}_0+\sqrt{1-\bar{\alpha}_t}\bold\epsilon,t)||^2 θϵϵθ(αˉt x0+1αˉt ϵ,t)2
  7. until converged

Sampling

优化好网络 ϵ θ \epsilon_\theta ϵθ之后,可以从 x T x_T xT逐步获得 x 0 x_0 x0

  1. x T ∼ N ( 0 , I ) x_T\sim\mathcal{N}(\bold{0},\bold{I}) xTN(0,I)
  2. for t = T , T − 1 , . . . , 1 t=T,T-1,...,1 t=T,T1,...,1 do
  3. z ∼ N ( 0 , I ) \bold{z}\sim{\mathcal{N}(\bold{0},\bold{I})} zN(0,I) if t > 1 t>1 t>1 else z = 0 \bold{z}=\bold{0} z=0
  4. x t − 1 = μ θ ( x t , t ) + σ t z = 1 α t ( x − 1 − α t 1 − α t ϵ θ ( x t , t ) + σ t z \bold{x}_{t-1}=\mu_\theta(\bold{x}_t,t)+\sigma_t\bold{z}=\frac{1}{\sqrt{\alpha_t}}\left(\bold{x}-\frac{1-\alpha_t}{\sqrt{1-\alpha_t}}\epsilon_\theta(\bold{x}_t,t\right)+\sigma_t\bold{z} xt1=μθ(xt,t)+σtz=αt 1(x1αt 1αtϵθ(xt,t)+σtz
  5. end for
  6. return x 0 x_0 x0

你可能感兴趣的:(视频编码,计算机视觉,论文笔记,算法,机器学习,python)