DDPM = 贝叶斯 + 去噪

Contents

  • 请贝叶斯
  • 去噪过程
  • 预估修正
  • Random Sample - 方差选取
  • References

  • 前面两篇文章给出了 DDPM 的两种推导,“DDPM = 拆楼 + 建楼” 更为直白易懂,但无法做更多的理论延伸和定量理解,“DDPM = 自回归式 VAE” 理论分析上更加完备一些,但稍显形式化,启发性不足。下面再分享 DDPM 的一种推导,它主要利用到了贝叶斯定理来简化计算,整个过程的 “推敲” 味道颇浓,很有启发性。不仅如此,它还跟 DDIM 模型有着紧密的联系

请贝叶斯

  • 利用贝叶斯公式,理论上我们想要获得如下生成过程 p ( x t − 1 ∣ x t ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) p(xt1xt) 的表示
    p ( x t − 1 ∣ x t ) = p ( x t ∣ x t − 1 ) p ( x t − 1 ) p ( x t ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) = \frac{p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})p(\boldsymbol{x}_{t-1})}{p(\boldsymbol{x}_t)} p(xt1xt)=p(xt)p(xtxt1)p(xt1)然而,我们并不知道 p ( x t − 1 ) , p ( x t ) p(\boldsymbol{x}_{t-1}),p(\boldsymbol{x}_t) p(xt1),p(xt), p ( x t − 1 ) , p ( x t ) p(\boldsymbol{x}_{t-1}),p(\boldsymbol{x}_t) p(xt1),p(xt) 的表达式,所以此路不通。但我们可以退而求其次,在给定 x 0 \boldsymbol{x}_0 x0 的条件下使用贝叶斯定理
    p ( x t − 1 ∣ x t , x 0 ) = p ( x t ∣ x t − 1 ) p ( x t − 1 ∣ x 0 ) p ( x t ∣ x 0 ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0) = \frac{p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)}{p(\boldsymbol{x}_t|\boldsymbol{x}_0)} p(xt1xt,x0)=p(xtx0)p(xtxt1)p(xt1x0)其中 p ( x t ∣ x t − 1 ) = N ( x t ; α t x t − 1 , β t 2 I ) p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})=\mathcal{N}(\boldsymbol{x}_t;\alpha_t \boldsymbol{x}_{t-1}, \beta_t^2 \boldsymbol{I}) p(xtxt1)=N(xt;αtxt1,βt2I), p ( x t − 1 ∣ x 0 ) = N ( x t − 1 ; α ˉ t − 1 x 0 , β ˉ t − 1 2 I ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_{t-1};\bar{\alpha}_{t-1} \boldsymbol{x}_0, \bar{\beta}_{t-1}^2 \boldsymbol{I}) p(xt1x0)=N(xt1;αˉt1x0,βˉt12I), p ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , β ˉ t 2 I ) p(\boldsymbol{x}_{t}|\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_{t};\bar{\alpha}_{t} \boldsymbol{x}_0, \bar{\beta}_{t}^2 \boldsymbol{I}) p(xtx0)=N(xt;αˉtx0,βˉt2I). 代入可得指数部分除掉 − 1 / 2 −1/2 1/2 因子外,结果是:
    ∥ x t − α t x t − 1 ∥ 2 β t 2 + ∥ x t − 1 − α ˉ t − 1 x 0 ∥ 2 β ˉ t − 1 2 − ∥ x t − α ˉ t x 0 ∥ 2 β ˉ t 2 \frac{\Vert \boldsymbol{x}_t - \alpha_t \boldsymbol{x}_{t-1}\Vert^2}{\beta_t^2} + \frac{\Vert \boldsymbol{x}_{t-1} - \bar{\alpha}_{t-1}\boldsymbol{x}_0\Vert^2}{\bar{\beta}_{t-1}^2} - \frac{\Vert \boldsymbol{x}_t - \bar{\alpha}_t \boldsymbol{x}_0\Vert^2}{\bar{\beta}_t^2} βt2xtαtxt12+βˉt12xt1αˉt1x02βˉt2xtαˉtx02它关于 x t − 1 \boldsymbol{x}_{t-1} xt1 是二次的,因此最终的分布必然也是正态分布,我们只需要求出其均值和协方差。不难看出,展开式中 ∥ x t − 1 ∥ 2 \Vert \boldsymbol{x}_{t-1}\Vert^2 xt12 项的系数是
    α t 2 β t 2 + 1 β ˉ t − 1 2 = α t 2 β ˉ t − 1 2 + β t 2 β ˉ t − 1 2 β t 2 = α t 2 ( 1 − α ˉ t − 1 2 ) + β t 2 β ˉ t − 1 2 β t 2 = 1 − α ˉ t 2 β ˉ t − 1 2 β t 2 = β ˉ t 2 β ˉ t − 1 2 β t 2 \frac{\alpha_t^2}{\beta_t^2} + \frac{1}{\bar{\beta}_{t-1}^2} = \frac{\alpha_t^2\bar{\beta}_{t-1}^2 + \beta_t^2}{\bar{\beta}_{t-1}^2 \beta_t^2} = \frac{\alpha_t^2(1-\bar{\alpha}_{t-1}^2) + \beta_t^2}{\bar{\beta}_{t-1}^2 \beta_t^2} = \frac{1-\bar{\alpha}_t^2}{\bar{\beta}_{t-1}^2 \beta_t^2} = \frac{\bar{\beta}_t^2}{\bar{\beta}_{t-1}^2 \beta_t^2} βt2αt2+βˉt121=βˉt12βt2αt2βˉt12+βt2=βˉt12βt2αt2(1αˉt12)+βt2=βˉt12βt21αˉt2=βˉt12βt2βˉt2所以整理好的结果必然是 β ˉ t 2 β ˉ t − 1 2 β t 2 ∥ x t − 1 − μ ~ ( x t , x 0 ) ∥ 2 \frac{\bar{\beta}_t^2}{\bar{\beta}_{t-1}^2 \beta_t^2}\Vert \boldsymbol{x}_{t-1} - \tilde{\boldsymbol{\mu}}(\boldsymbol{x}_t, \boldsymbol{x}_0)\Vert^2 βˉt12βt2βˉt2xt1μ~(xt,x0)2 的形式 (协方差矩阵必然是对角矩阵。此外,由于二次项系数都相同,因此协方差矩阵必为单位矩阵的倍数),这意味着协方差矩阵是 β ˉ t − 1 2 β t 2 β ˉ t 2 I \frac{\bar{\beta}_{t-1}^2 \beta_t^2}{\bar{\beta}_t^2}\boldsymbol{I} βˉt2βˉt12βt2I。另一边,把一次项系数拿出来是 − 2 ( α t β t 2 x t + α ˉ t − 1 β ˉ t − 1 2 x 0 ) -2\left(\frac{\alpha_t}{\beta_t^2}\boldsymbol{x}_t + \frac{\bar{\alpha}_{t-1}}{\bar{\beta}_{t-1}^2}\boldsymbol{x}_0 \right) 2(βt2αtxt+βˉt12αˉt1x0),除以 − 2 β ˉ t 2 β ˉ t − 1 2 β t 2 \frac{-2\bar{\beta}_t^2}{\bar{\beta}_{t-1}^2 \beta_t^2} βˉt12βt22βˉt2 后便可以得到
    μ ~ ( x t , x 0 ) = α t β ˉ t − 1 2 β ˉ t 2 x t + α ˉ t − 1 β t 2 β ˉ t 2 x 0 \tilde{\boldsymbol{\mu}}(\boldsymbol{x}_t, \boldsymbol{x}_0)=\frac{\alpha_t\bar{\beta}_{t-1}^2}{\bar{\beta}_t^2}\boldsymbol{x}_t + \frac{\bar{\alpha}_{t-1}\beta_t^2}{\bar{\beta}_t^2}\boldsymbol{x}_0 μ~(xt,x0)=βˉt2αtβˉt12xt+βˉt2αˉt1βt2x0最终得到下式,它可以借助原图像完成对当前图像的去噪
    p ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; α t β ˉ t − 1 2 β ˉ t 2 x t + α ˉ t − 1 β t 2 β ˉ t 2 x 0 , β ˉ t − 1 2 β t 2 β ˉ t 2 I ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0) = \mathcal{N}\left(\boldsymbol{x}_{t-1};\frac{\alpha_t\bar{\beta}_{t-1}^2}{\bar{\beta}_t^2}\boldsymbol{x}_t + \frac{\bar{\alpha}_{t-1}\beta_t^2}{\bar{\beta}_t^2}\boldsymbol{x}_0,\frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2} \boldsymbol{I}\right) p(xt1xt,x0)=N(xt1;βˉt2αtβˉt12xt+βˉt2αˉt1βt2x0,βˉt2βˉt12βt2I)

去噪过程

  • 下面我们需要在不借助原图像 x 0 \boldsymbol{x}_0 x0 的前提下完成去噪。一个 “异想天开” 的想法是 μ ˉ ( x t ) \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) μˉ(xt) 来预估 x 0 \boldsymbol{x}_0 x0,损失函数为 ∥ x 0 − μ ˉ ( x t ) ∥ 2 \Vert \boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\Vert^2 x0μˉ(xt)2,这实际上在训练一个去噪模型,这也就是 DDPM 的第一个 “D” 的含义 (Denoising). 由于 x 0 = 1 α ˉ t ( x t − β ˉ t ε ) \boldsymbol{x}_0 = \frac{1}{\bar{\alpha}_t}\left(\boldsymbol{x}_t - \bar{\beta}_t \boldsymbol{\varepsilon}\right) x0=αˉt1(xtβˉtε),因此将 μ ˉ ( x t ) \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) μˉ(xt) 参数化为
    μ ˉ ( x t ) = 1 α ˉ t ( x t − β ˉ t ϵ θ ( x t , t ) ) \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) = \frac{1}{\bar{\alpha}_t}\left(\boldsymbol{x}_t - \bar{\beta}_t \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right) μˉ(xt)=αˉt1(xtβˉtϵθ(xt,t))此时损失函数变为
    ∥ x 0 − μ ˉ ( x t ) ∥ 2 = β ˉ t 2 α ˉ t 2 ∥ ε − ϵ θ ( α ˉ t x 0 + β ˉ t ε , t ) ∥ 2 \Vert \boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\Vert^2 = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\left\Vert\boldsymbol{\varepsilon} - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\bar{\alpha}_t \boldsymbol{x}_0 + \bar{\beta}_t \boldsymbol{\varepsilon}, t)\right\Vert^2 x0μˉ(xt)2=αˉt2βˉt2εϵθ(αˉtx0+βˉtε,t)2省去前面的系数,就得到 DDPM 原论文所用的损失函数了 (提示:出于推导的流畅性考虑,这里的 ϵ θ \boldsymbol{\epsilon}_{\boldsymbol{\theta}} ϵθ 跟前两个视角介绍不一样,反而跟 DDPM 原论文一致)。可以发现,这里是直接得出了从 x t \boldsymbol{x}_t xt x 0 \boldsymbol{x}_0 x0 的去噪过程,而不是像之前两个视角那样,通过 x t \boldsymbol{x}_t xt x t − 1 \boldsymbol{x}_{t-1} xt1 的去噪过程再加上积分变换来推导,相比之下这里的推导可谓更加一步到位了
  • 训练完成后,我们就认为
    p ( x t − 1 ∣ x t ) ≈ p ( x t − 1 ∣ x t , x 0 = μ ˉ ( x t ) ) = N ( x t − 1 ; α t β ˉ t − 1 2 β ˉ t 2 x t + α ˉ t − 1 β t 2 β ˉ t 2 μ ˉ ( x t ) , β ˉ t − 1 2 β t 2 β ˉ t 2 I ) = N ( x t − 1 ; 1 α t ( x t − β t 2 β ˉ t ϵ θ ( x t , t ) ) , β ˉ t − 1 2 β t 2 β ˉ t 2 I ) \begin{aligned} p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) &\approx p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0=\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)) \\&= \mathcal{N}\left(\boldsymbol{x}_{t-1}; \frac{\alpha_t\bar{\beta}_{t-1}^2}{\bar{\beta}_t^2}\boldsymbol{x}_t + \frac{\bar{\alpha}_{t-1}\beta_t^2}{\bar{\beta}_t^2}\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t),\frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2} \boldsymbol{I}\right) \\&= \mathcal{N}\left(\boldsymbol{x}_{t-1}; \frac{1}{\alpha_t}\left(\boldsymbol{x}_t - \frac{\beta_t^2}{\bar{\beta}_t}\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right),\frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2} \boldsymbol{I}\right) \end{aligned} p(xt1xt)p(xt1xt,x0=μˉ(xt))=N(xt1;βˉt2αtβˉt12xt+βˉt2αˉt1βt2μˉ(xt),βˉt2βˉt12βt2I)=N(xt1;αt1(xtβˉtβt2ϵθ(xt,t)),βˉt2βˉt12βt2I)这就是反向的采样过程所用的分布,连同采样过程所用的方差也一并确定下来了

预估修正

  • 不知道读者有没有留意到一个有趣的地方:我们要做的事情,就是想将 x T \boldsymbol{x}_T xT 慢慢地变为 x 0 \boldsymbol{x}_0 x0,而我们在借用 p ( x t − 1 ∣ x t , x 0 ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0) p(xt1xt,x0) 近似 p ( x t − 1 ∣ x t ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) p(xt1xt) 时,却包含了 “用 μ ˉ ( x t ) \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) μˉ(xt) 来预估 x 0 \boldsymbol{x}_0 x0” 这一步,要是能预估准的话,那就直接一步到位了,还需要逐步采样吗?
  • 真实情况是,“用 μ ˉ ( x t ) \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) μˉ(xt) 来预估 x 0 \boldsymbol{x}_0 x0” 当然不会太准的,至少开始的相当多步内不会太准。它仅仅起到了一个前瞻性的预估作用,然后我们只用 p ( x t − 1 ∣ x t ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) p(xt1xt) 来推进一小步,这就是很多数值算法中的 “预估-修正” 思想,即我们用一个粗糙的解往前推很多步,然后利用这个粗糙的结果将最终结果推进一小步,以此来逐步获得更为精细的解

Random Sample - 方差选取

  • (1) 假设整个数据集只有一个样本,不失一般性,假设该样本为 0 \boldsymbol{0} 0,此时 p ~ ( x 0 ) \tilde{p}(\boldsymbol{x}_0) p~(x0)狄拉克分布 δ ( x 0 ) \delta(\boldsymbol{x}_0) δ(x0),可以直接算出 p ( x t ) = p ( x t ∣ 0 ) p(\boldsymbol{x}_t)=p(\boldsymbol{x}_t|\boldsymbol{0}) p(xt)=p(xt0)。代入下式
    p ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; α t β ˉ t − 1 2 β ˉ t 2 x t + α ˉ t − 1 β t 2 β ˉ t 2 x 0 , β ˉ t − 1 2 β t 2 β ˉ t 2 I ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0) = \mathcal{N}\left(\boldsymbol{x}_{t-1};\frac{\alpha_t\bar{\beta}_{t-1}^2}{\bar{\beta}_t^2}\boldsymbol{x}_t + \frac{\bar{\alpha}_{t-1}\beta_t^2}{\bar{\beta}_t^2}\boldsymbol{x}_0,\frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2} \boldsymbol{I}\right) p(xt1xt,x0)=N(xt1;βˉt2αtβˉt12xt+βˉt2αˉt1βt2x0,βˉt2βˉt12βt2I)
    p ( x t − 1 ∣ x t ) = p ( x t − 1 ∣ x t , x 0 = 0 ) = N ( x t − 1 ; α t β ˉ t − 1 2 β ˉ t 2 x t , β ˉ t − 1 2 β t 2 β ˉ t 2 I ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) = p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0=\boldsymbol{0}) = \mathcal{N}\left(\boldsymbol{x}_{t-1};\frac{\alpha_t\bar{\beta}_{t-1}^2}{\bar{\beta}_t^2}\boldsymbol{x}_t,\frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2} \boldsymbol{I}\right) p(xt1xt)=p(xt1xt,x0=0)=N(xt1;βˉt2αtβˉt12xt,βˉt2βˉt12βt2I)我们主要关心其方差为 β ˉ t − 1 2 β t 2 β ˉ t 2 \frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2} βˉt2βˉt12βt2,这便是采样方差的选择之一
  • (2) 假设数据集服从标准正态分布,即 p ~ ( x 0 ) = N ( x 0 ; 0 , I ) \tilde{p}(\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_0;\boldsymbol{0},\boldsymbol{I}) p~(x0)=N(x0;0,I)。由于 x t = α ˉ t x 0 + β ˉ t ε , ε ∼ N ( 0 , I ) \boldsymbol{x}_t = \bar{\alpha}_t \boldsymbol{x}_0 + \bar{\beta}_t \boldsymbol{\varepsilon},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I}) xt=αˉtx0+βˉtε,εN(0,I) x 0 ∼ N ( 0 , I ) \boldsymbol{x}_0\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I}) x0N(0,I),所以由正态分布的叠加性, x t \boldsymbol{x}_t xt 正好也服从标准正态分布。现在有 p ( x t ∣ x t − 1 ) = N ( x t ; α t x t − 1 , β t 2 I ) p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})=\mathcal{N}(\boldsymbol{x}_t;\alpha_t \boldsymbol{x}_{t-1}, \beta_t^2 \boldsymbol{I}) p(xtxt1)=N(xt;αtxt1,βt2I), p ( x t − 1 ∣ x 0 ) = N ( x t − 1 ; 0 , I ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_{t-1};0, \boldsymbol{I}) p(xt1x0)=N(xt1;0,I), p ( x t ∣ x 0 ) = N ( x t ; 0 , I ) p(\boldsymbol{x}_{t}|\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_{t};0, \boldsymbol{I}) p(xtx0)=N(xt;0,I). 将标准正态分布的概率密度代入 p ( x t − 1 ∣ x t , x 0 ) = p ( x t ∣ x t − 1 ) p ( x t − 1 ∣ x 0 ) p ( x t ∣ x 0 ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0) = \frac{p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)}{p(\boldsymbol{x}_t|\boldsymbol{x}_0)} p(xt1xt,x0)=p(xtx0)p(xtxt1)p(xt1x0), 结果的指数部分除掉 − 1 / 2 −1/2 1/2 因子外,结果是:
    ∥ x t − α t x t − 1 ∥ 2 β t 2 + ∥ x t − 1 ∥ 2 − ∥ x t ∥ 2 \frac{\Vert \boldsymbol{x}_t - \alpha_t \boldsymbol{x}_{t-1}\Vert^2}{\beta_t^2} + \Vert \boldsymbol{x}_{t-1}\Vert^2 - \Vert \boldsymbol{x}_t\Vert^2 βt2xtαtxt12+xt12xt2跟推导 p ( x t − 1 ∣ x t , x 0 ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0) p(xt1xt,x0) 的过程类似,可以得到上述指数对应于
    p ( x t − 1 ∣ x t ) = N ( x t − 1 ; α t x t , β t 2 I ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) = \mathcal{N}\left(\boldsymbol{x}_{t-1};\alpha_t\boldsymbol{x}_t,\beta_t^2 \boldsymbol{I}\right) p(xt1xt)=N(xt1;αtxt,βt2I)我们同样主要关心其方差为 β t 2 \beta_t^2 βt2,这便是采样方差的另一个选择

References

  • 苏剑林. (Jul. 19, 2022). 《生成扩散模型漫谈(三):DDPM = 贝叶斯 + 去噪 》[Blog post]. Retrieved from https://kexue.fm/archives/9164

你可能感兴趣的:(#,Generative,Models,diffusion,model)