DDPM = 自回归式 VAE

Contents

  • 多步突破
  • 联合散度
  • 分而治之
  • 场景再现
  • 超参设置
  • References

  • DDPM 本质上已经不是传统的扩散模型了,它更多的是一个变分自编码器 VAE,实际上 DDPM 的原论文中也是将它按照 VAE 的思路进行推导的。所以,下面就从 VAE 的角度来重新介绍 DDPM

多步突破

  • 传统的 VAE 中,编码过程和生成过程都是一步到位的。这样做就只涉及到三个分布:编码分布 p ( z ∣ x ) p(z|x) p(zx)、生成分布 q ( x ∣ z ) q(x|z) q(xz) 以及先验分布 q ( z ) q(z) q(z),它的好处是形式比较简单, x x x z z z 之间的映射关系也比较确定,因此可以同时得到编码模型和生成模型,实现隐变量编辑等需求;但是它的缺点也很明显,因为我们建模概率分布的能力有限,三个分布都只能建模为正态分布,这限制了模型的表达能力,最终通常得到偏模糊的生成结果
  • 为了突破这个限制,DDPM 将编码过程和生成过程分解为 T T T,每一步编码过程 p ( x t ∣ x t − 1 ) p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1}) p(xtxt1) 和生成过程 q ( x t − 1 ∣ x t ) q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) q(xt1xt) 仅仅负责建模一个微小变化,它们依然建模为正态分布。对于微小变化来说,可以用正态分布足够近似地建模,类似于曲线在小范围内可以用直线近似,多步分解就有点像用分段线性函数拟合复杂曲线,因此理论上可以突破传统单步 VAE 的拟合能力限制。编码和生成公式为
    x t = α t x t − 1 + β t ε t , ε t ∼ N ( 0 , I ) x t − 1 = 1 α t ( x t − β t ϵ θ ( x t , t ) ) + σ t z = μ ( x t ) + σ t z , z ∼ N ( 0 , I ) \begin{aligned} \boldsymbol{x}_t &= \alpha_t \boldsymbol{x}_{t-1} + \beta_t \boldsymbol{\varepsilon}_t,\quad \boldsymbol{\varepsilon}_t\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})\\ \boldsymbol{x}_{t-1} &= \frac{1}{\alpha_t}\left(\boldsymbol{x}_t - \beta_t \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right) + \sigma_t \boldsymbol{z} \\&=\boldsymbol{\mu}(\boldsymbol{x}_t)+ \sigma_t \boldsymbol{z},\quad \boldsymbol{z}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I}) \end{aligned} xtxt1=αtxt1+βtεt,εtN(0,I)=αt1(xtβtϵθ(xt,t))+σtz=μ(xt)+σtz,zN(0,I)可以将编码过程和生成过程写为下式:
    p ( x t ∣ x t − 1 ) = N ( x t ; α t x t − 1 , β t 2 I ) q ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ ( x t ) , σ t 2 I ) p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})=\mathcal{N}(\boldsymbol{x}_t;\alpha_t \boldsymbol{x}_{t-1}, \beta_t^2 \boldsymbol{I})\\ q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)=\mathcal{N}(\boldsymbol{x}_{t-1};\boldsymbol{\mu}(\boldsymbol{x}_t), \sigma_t^2 \boldsymbol{I}) p(xtxt1)=N(xt;αtxt1,βt2I)q(xt1xt)=N(xt1;μ(xt),σt2I)可以看到,在编码时,传统 VAE 的均值方差都是用神经网络学习出来的,而 DDPM 放弃了模型的编码能力,最终只得到一个纯粹的生成模型;在生成时,DDPM 则将生成过程建模成均值向量可学习的正态分布 N ( x t − 1 ; μ ( x t ) , σ t 2 I ) \mathcal{N}(\boldsymbol{x}_{t-1};\boldsymbol{\mu}(\boldsymbol{x}_t), \sigma_t^2 \boldsymbol{I}) N(xt1;μ(xt),σt2I)。整个模型拥有可训练参数的就只有 μ ( x t ) \boldsymbol{\mu}(\boldsymbol{x}_t) μ(xt)

联合散度

  • 理解 VAE 最简洁的理论途径,就是将其理解为在最小化联合分布的 KL 散度,对于 DDPM 也是如此。生成和编码的联合分布分别为
    p ( x 0 , x 1 , x 2 , ⋯   , x T ) = p ( x T ∣ x T − 1 ) ⋯ p ( x 2 ∣ x 1 ) p ( x 1 ∣ x 0 ) p ~ ( x 0 ) q ( x 0 , x 1 , x 2 , ⋯   , x T ) = q ( x 0 ∣ x 1 ) ⋯ q ( x T − 2 ∣ x T − 1 ) q ( x T − 1 ∣ x T ) q ( x T ) \begin{aligned}&p(\boldsymbol{x}_0, \boldsymbol{x}_1, \boldsymbol{x}_2, \cdots, \boldsymbol{x}_T) = p(\boldsymbol{x}_T|\boldsymbol{x}_{T-1})\cdots p(\boldsymbol{x}_2|\boldsymbol{x}_1) p(\boldsymbol{x}_1|\boldsymbol{x}_0) \tilde{p}(\boldsymbol{x}_0) \\ &q(\boldsymbol{x}_0, \boldsymbol{x}_1, \boldsymbol{x}_2, \cdots, \boldsymbol{x}_T) = q(\boldsymbol{x}_0|\boldsymbol{x}_1)\cdots q(\boldsymbol{x}_{T-2}|\boldsymbol{x}_{T-1}) q(\boldsymbol{x}_{T-1}|\boldsymbol{x}_T) q(\boldsymbol{x}_T) \end{aligned} p(x0,x1,x2,,xT)=p(xTxT1)p(x2x1)p(x1x0)p~(x0)q(x0,x1,x2,,xT)=q(x0x1)q(xT2xT1)q(xT1xT)q(xT)DDPM 的目的就是最小化两个联合分布之间的 KL 散度
    K L ( p ∥ q ) = ∫ p ( x T ∣ x T − 1 ) ⋯ p ( x 1 ∣ x 0 ) p ~ ( x 0 ) log ⁡ p ( x T ∣ x T − 1 ) ⋯ p ( x 1 ∣ x 0 ) p ~ ( x 0 ) q ( x 0 ∣ x 1 ) ⋯ q ( x T − 1 ∣ x T ) q ( x T ) d x 0 d x 1 ⋯ d x T KL(p\Vert q) = \int p(\boldsymbol{x}_T|\boldsymbol{x}_{T-1})\cdots p(\boldsymbol{x}_1|\boldsymbol{x}_0) \tilde{p}(\boldsymbol{x}_0) \log \frac{p(\boldsymbol{x}_T|\boldsymbol{x}_{T-1})\cdots p(\boldsymbol{x}_1|\boldsymbol{x}_0) \tilde{p}(\boldsymbol{x}_0)}{q(\boldsymbol{x}_0|\boldsymbol{x}_1)\cdots q(\boldsymbol{x}_{T-1}|\boldsymbol{x}_T) q(\boldsymbol{x}_T)} d\boldsymbol{x}_0 d\boldsymbol{x}_1\cdots d\boldsymbol{x}_T KL(pq)=p(xTxT1)p(x1x0)p~(x0)logq(x0x1)q(xT1xT)q(xT)p(xTxT1)p(x1x0)p~(x0)dx0dx1dxT

分而治之

  • 下面对上述 KL 散度进行化简。由于目前分布 p p p 不含任何的可训练参数,因此关于 p p p 的积分就只是贡献一个可以忽略的常数,所以 KL 散度等价于
      − ∫ p ( x T ∣ x T − 1 ) ⋯ p ( x 1 ∣ x 0 ) p ~ ( x 0 ) log ⁡ q ( x 0 ∣ x 1 ) ⋯ q ( x T − 1 ∣ x T ) q ( x T ) d x 0 d x 1 ⋯ d x T =   − ∫ p ( x T ∣ x T − 1 ) ⋯ p ( x 1 ∣ x 0 ) p ~ ( x 0 ) [ log ⁡ q ( x T ) + ∑ t = 1 T log ⁡ q ( x t − 1 ∣ x t ) ] d x 0 d x 1 ⋯ d x T \begin{aligned}&\,-\int p(\boldsymbol{x}_T|\boldsymbol{x}_{T-1})\cdots p(\boldsymbol{x}_1|\boldsymbol{x}_0) \tilde{p}(\boldsymbol{x}_0) \log q(\boldsymbol{x}_0|\boldsymbol{x}_1)\cdots q(\boldsymbol{x}_{T-1}|\boldsymbol{x}_T) q(\boldsymbol{x}_T) d\boldsymbol{x}_0 d\boldsymbol{x}_1\cdots d\boldsymbol{x}_T \\ =&\,-\int p(\boldsymbol{x}_T|\boldsymbol{x}_{T-1})\cdots p(\boldsymbol{x}_1|\boldsymbol{x}_0) \tilde{p}(\boldsymbol{x}_0) \left[\log q(\boldsymbol{x}_T) + \sum_{t=1}^T\log q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)\right] d\boldsymbol{x}_0 d\boldsymbol{x}_1\cdots d\boldsymbol{x}_T \end{aligned} =p(xTxT1)p(x1x0)p~(x0)logq(x0x1)q(xT1xT)q(xT)dx0dx1dxTp(xTxT1)p(x1x0)p~(x0)[logq(xT)+t=1Tlogq(xt1xt)]dx0dx1dxT由于先验分布 q ( x T ) q(\boldsymbol{x}_T) q(xT) 一般都取标准正态分布,也是没有参数的,所以这一项也只是贡献一个常数。因此需要计算的就是每一项
      − ∫ p ( x T ∣ x T − 1 ) ⋯ p ( x 1 ∣ x 0 ) p ~ ( x 0 ) log ⁡ q ( x t − 1 ∣ x t ) d x 0 d x 1 ⋯ d x T =   − ∫ p ( x t ∣ x t − 1 ) ⋯ p ( x 1 ∣ x 0 ) p ~ ( x 0 ) log ⁡ q ( x t − 1 ∣ x t ) d x 0 d x 1 ⋯ d x t =   − ∫ p ( x t ∣ x t − 1 ) p ( x t − 1 , x t − 2 , . . . , x 1 ∣ x 0 ) p ~ ( x 0 ) log ⁡ q ( x t − 1 ∣ x t ) d x 0 d x t − 1 d x t = − ∫ p ( x t ∣ x t − 1 ) p ( x t − 1 ∣ x 0 ) p ~ ( x 0 ) log ⁡ q ( x t − 1 ∣ x t ) d x 0 d x t − 1 d x t = ∫ p ( x t ∣ x t − 1 ) p ( x t − 1 ∣ x 0 ) [ E x 0 ∼ p ~ ( x 0 ) − log ⁡ q ( x t − 1 ∣ x t ) ] d x t − 1 d x t \begin{aligned}&\,-\int p(\boldsymbol{x}_T|\boldsymbol{x}_{T-1})\cdots p(\boldsymbol{x}_1|\boldsymbol{x}_0) \tilde{p}(\boldsymbol{x}_0) \log q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) d\boldsymbol{x}_0 d\boldsymbol{x}_1\cdots d\boldsymbol{x}_T\\ =&\,-\int p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})\cdots p(\boldsymbol{x}_1|\boldsymbol{x}_0) \tilde{p}(\boldsymbol{x}_0) \log q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) d\boldsymbol{x}_0 d\boldsymbol{x}_1\cdots d\boldsymbol{x}_t\\ =&\,-\int p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})p(\boldsymbol{x}_{t-1},\boldsymbol{x}_{t-2},...,\boldsymbol{x}_{1}|\boldsymbol{x}_0) \tilde{p}(\boldsymbol{x}_0) \log q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) d\boldsymbol{x}_0 d\boldsymbol{x}_{t-1}d\boldsymbol{x}_t \\=&-\int p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0) \tilde{p}(\boldsymbol{x}_0) \log q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) d\boldsymbol{x}_0 d\boldsymbol{x}_{t-1}d\boldsymbol{x}_t \\=&\int p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)\left[\mathbb E_{\boldsymbol{x}_0\sim\tilde{p}(\boldsymbol{x}_0)}- \log q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)\right] d\boldsymbol{x}_{t-1}d\boldsymbol{x}_t \end{aligned} ====p(xTxT1)p(x1x0)p~(x0)logq(xt1xt)dx0dx1dxTp(xtxt1)p(x1x0)p~(x0)logq(xt1xt)dx0dx1dxtp(xtxt1)p(xt1,xt2,...,x1x0)p~(x0)logq(xt1xt)dx0dxt1dxtp(xtxt1)p(xt1x0)p~(x0)logq(xt1xt)dx0dxt1dxtp(xtxt1)p(xt1x0)[Ex0p~(x0)logq(xt1xt)]dxt1dxt其中第一个等号是因为 q ( x t − 1 ∣ x t ) q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) q(xt1xt) 至多依赖到 x t \boldsymbol{x}_t xt,因此 t + 1 t+1 t+1 T T T 的分布可以直接积分为 1;第三个等号则是因为 q ( x t − 1 ∣ x t ) q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) q(xt1xt) 也不依赖于 x 1 , ⋯   , x t − 2 \boldsymbol{x}_1,\cdots,\boldsymbol{x}_{t-2} x1,,xt2,所以关于它们的积分我们也可以事先算出。此外, p ( x t − 1 ∣ x 0 ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0) p(xt1x0) 也可以直接写为 p ( x t − 1 ∣ x 0 ) = N ( x t − 1 ; α ˉ t − 1 x 0 , β ˉ t − 1 2 I ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_{t-1};\bar{\alpha}_{t-1} \boldsymbol{x}_0, \bar{\beta}_{t-1}^2 \boldsymbol{I}) p(xt1x0)=N(xt1;αˉt1x0,βˉt12I)
  • 此外,由于拥有可训练参数的就只有 q ( x t − 1 ∣ x t ) q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) q(xt1xt),因此损失函数等价于
    E x 0 ∼ p ~ ( x 0 ) − log ⁡ q ( x t − 1 ∣ x t ) \mathbb E_{\boldsymbol{x}_0\sim\tilde{p}(\boldsymbol{x}_0)}- \log q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) Ex0p~(x0)logq(xt1xt)

场景再现

  • 正态分布为 N ( x ∣ μ , Σ ) = 1 ( 2 π ) D / 2 1 ∣ Σ ∣ 1 / 2 exp ⁡ { − 1 2 ( x − μ ) T Σ − 1 ( x − μ ) } \mathcal{N}(\mathrm{x} \mid \boldsymbol{\mu}, \boldsymbol{\Sigma})=\frac{1}{(2 \pi)^{D / 2}} \frac{1}{|\boldsymbol{\Sigma}|^{1 / 2}} \exp \left\{-\frac{1}{2}(\mathrm{x}-\boldsymbol{\mu})^{\mathrm{T}} \boldsymbol{\Sigma}^{-1}(\mathrm{x}-\boldsymbol{\mu})\right\} N(xμ,Σ)=(2π)D/21Σ1/21exp{21(xμ)TΣ1(xμ)},因此除去优化无关的常数,有
    − log ⁡ q ( x t − 1 ∣ x t ) ∼ 1 2 σ t 2 ∥ x t − 1 − μ ( x t ) ∥ 2 -\log q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)\sim \frac{1}{2\sigma_t^2}\left\Vert\boldsymbol{x}_{t-1} - \boldsymbol{\mu}(\boldsymbol{x}_t)\right\Vert^2 logq(xt1xt)2σt21xt1μ(xt)2 x t − 1 = 1 α t ( x t − β t ε t ) \boldsymbol{x}_{t-1} = \frac{1}{\alpha_t}\left(\boldsymbol{x}_t - \beta_t \boldsymbol{\varepsilon}_t\right) xt1=αt1(xtβtεt) 启发,将 μ ( x t ) \boldsymbol{\mu}(\boldsymbol{x}_t) μ(xt) 写为 μ ( x t ) = 1 α t ( x t − β t ϵ θ ( x t , t ) ) \boldsymbol{\mu}(\boldsymbol{x}_t) = \frac{1}{\alpha_t}\left(\boldsymbol{x}_t - \beta_t \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right) μ(xt)=αt1(xtβtϵθ(xt,t)) 的形式,可得
    − log ⁡ q ( x t − 1 ∣ x t ) ∼ 1 σ t 2 ∥ x t − 1 − μ ( x t ) ∥ 2 = β t 2 α t 2 σ t 2 ∥ ε t − ϵ θ ( x t , t ) ∥ 2 = β t 2 α t 2 σ t 2 ∥ ε t − ϵ θ ( α ˉ t x 0 + α t β ˉ t − 1 ε ˉ t − 1 + β t ε t , t ) ∥ 2 \begin{aligned} -\log q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) &\sim \frac{1}{\sigma_t^2}\left\Vert\boldsymbol{x}_{t-1} - \boldsymbol{\mu}(\boldsymbol{x}_t)\right\Vert^2 \\&=\frac{\beta_t^2}{\alpha_t^2\sigma_t^2}\left\Vert\boldsymbol{\varepsilon}_t -\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right\Vert^2 \\&=\frac{\beta_t^2}{\alpha_t^2\sigma_t^2}\left\Vert \boldsymbol{\varepsilon}_t - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\bar{\alpha}_t\boldsymbol{x}_0 + \alpha_t\bar{\beta}_{t-1}\bar{\boldsymbol{\varepsilon}}_{t-1} + \beta_t \boldsymbol{\varepsilon}_t, t)\right\Vert^2 \end{aligned} logq(xt1xt)σt21xt1μ(xt)2=αt2σt2βt2εtϵθ(xt,t)2=αt2σt2βt2εtϵθ(αˉtx0+αtβˉt1εˉt1+βtεt,t)2再按照 “降低方差” 一节做换元,可得
    − log ⁡ q ( x t − 1 ∣ x t ) ∼ β t 4 β ˉ t 2 α t 2 σ t 2 ∥ ε − β ˉ t β t ϵ θ ( α ˉ t x 0 + β ˉ t ε , t ) ∥ 2 \begin{aligned} -\log q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) &\sim\frac{\beta_t^4}{\bar\beta_t^2\alpha_t^2\sigma_t^2}\left\Vert\boldsymbol{\varepsilon} - \frac{\bar{\beta}_t}{\beta_t}\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\bar{\alpha}_t\boldsymbol{x}_0 + \bar{\beta}_t\boldsymbol{\varepsilon}, t)\right\Vert^2 \end{aligned} logq(xt1xt)βˉt2αt2σt2βt4εβtβˉtϵθ(αˉtx0+βˉtε,t)2
  • 现在损失函数可以写为
    β t 4 β ˉ t 2 α t 2 σ t 2 E ε ∼ N ( 0 , I ) , x 0 ∼ p ~ ( x 0 ) [ ∥ ε − β ˉ t β t ϵ θ ( α ˉ t x 0 + β ˉ t ε , t ) ∥ 2 ] \frac{\beta_t^4}{\bar{\beta}_t^2\alpha_t^2\sigma_t^2}\mathbb{E}_{\boldsymbol{\varepsilon}\sim \mathcal{N}(\boldsymbol{0},\boldsymbol{I}),\boldsymbol{x}_0\sim \tilde{p}(\boldsymbol{x}_0)}\left[\left\Vert\boldsymbol{\varepsilon} - \frac{\bar{\beta}_t}{\beta_t}\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\bar{\alpha}_t\boldsymbol{x}_0 + \bar{\beta}_t\boldsymbol{\varepsilon}, t)\right\Vert^2\right] βˉt2αt2σt2βt4EεN(0,I),x0p~(x0)[εβtβˉtϵθ(αˉtx0+βˉtε,t)2]这就得到了 DDPM 的训练目标了(原论文通过实验发现,去掉上式前面的系数后实际效果更好些

超参设置

  • 从最小化两个联合分布的 KL 散度出发,我们可以进一步理解为什么要设计适当的 α t \alpha_t αt 使得 α ˉ T ≈ 0 \bar{\alpha}_T\approx 0 αˉT0. 前面说了, q ( x T ) q(\boldsymbol{x}_T) q(xT) 一般都取标准正态分布 N ( x T ; 0 , I ) \mathcal{N}(\boldsymbol{x}_T;\boldsymbol{0}, \boldsymbol{I}) N(xT;0,I)。而我们的学习目标是最小化两个联合分布的 KL 散度,即希望 p = q p=q p=q,那么它们的边缘分布自然也相等,所以我们也希望
    q ( x T ) = ∫ p ( x T ∣ x 0 ) p ~ ( x 0 ) d x 0 q(\boldsymbol{x}_T) = \int p(\boldsymbol{x}_T|\boldsymbol{x}_0) \tilde{p}(\boldsymbol{x}_0) d\boldsymbol{x}_0 q(xT)=p(xTx0)p~(x0)dx0由于数据分布 p ~ ( x 0 ) \tilde{p}(\boldsymbol{x}_0) p~(x0) 是任意的,所以要使上式恒成立,只能让 p ( x T ∣ x 0 ) = q ( x T ) p(\boldsymbol{x}_T|\boldsymbol{x}_0)=q(\boldsymbol{x}_T) p(xTx0)=q(xT),即退化为与 x 0 \boldsymbol{x}_0 x0 无关的标准正态分布

References

  • 苏剑林. (Jul. 06, 2022). 《生成扩散模型漫谈(二):DDPM = 自回归式 VAE 》[Blog post]. Retrieved from https://kexue.fm/archives/9152

你可能感兴趣的:(#,Generative,Models,diffusion,model)