深度学习(生成式模型)——Classifier Guidance Diffusion

文章目录

  • 前言
  • 问题建模
  • 条件扩散模型的前向过程
  • 条件扩散模型的反向过程
  • 条件扩散模型的训练目标

前言

几乎所有的生成式模型,发展到后期都需要引入"控制"的概念,可控制的生成式模型才能更好应用于实际场景。本文将总结《Diffusion Models Beat GANs on Image Synthesis》中提出的Classifier Guidance Diffusion(即条件扩散模型),其往Diffusion Model中引入了控制的概念,可以控制DDPM、DDIM生成指定类别(条件)的图片。

问题建模

本章节所有符号定义与DDPM一致,在条件 y y y下的Diffusion Model的前向与反向过程可以定义为
q ^ ( x t + 1 ∣ x t , y ) q ^ ( x t ∣ x t + 1 , y ) \begin{aligned} \hat q(x_{t+1}|x_{t},y)\\ \hat q(x_t|x_{t+1},y) \end{aligned} q^(xt+1xt,y)q^(xtxt+1,y)
只要求出上述两个概率密度函数,我们即可按条件生成图像。

我们利用 q ^ \hat q q^表示条件扩散模型的概率密度函数, q q q表示扩散模型的概率密度函数。

条件扩散模型的前向过程

对于前向过程,作者定义了以下等式
q ^ ( x 0 ) = q ( x 0 ) q ^ ( x t + 1 ∣ x t , y ) = q ( x t + 1 ∣ x t ) q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 T q ^ ( x t ∣ x t − 1 , y ) \begin{aligned} \hat q(x_0)&=q(x_0)\\ \hat q(x_{t+1}|x_t,y)&=q(x_{t+1}|x_t)\\ \hat q(x_{1:T}|x_0,y)&=\prod_{t=1}^T\hat q(x_t|x_{t-1},y) \end{aligned} q^(x0)q^(xt+1xt,y)q^(x1:Tx0,y)=q(x0)=q(xt+1xt)=t=1Tq^(xtxt1,y)

基于上述第二行定义,可知基于条件 y y y的diffusion model的前向过程与普通的diffusion model一致,即 q ^ ( x t + 1 ∣ x t ) = q ( x t + 1 ∣ x t ) \hat q(x_{t+1}|x_t)=q(x_{t+1}|x_t) q^(xt+1xt)=q(xt+1xt)。即加噪过程与条件 y y y无关,这种定义也是合理的。

条件扩散模型的反向过程

对于反向过程,我们有
q ^ ( x t ∣ x t + 1 , y ) = q ^ ( x t , x t + 1 , y ) q ^ ( x t + 1 , y ) = q ^ ( x t , x t + 1 , y ) q ^ ( y ∣ x t + 1 ) q ^ ( x t + 1 ) = q ^ ( x t , y ∣ x t + 1 ) q ^ ( y ∣ x t + 1 ) = q ^ ( y ∣ x t , x t + 1 ) q ^ ( x t ∣ x t + 1 ) q ^ ( y ∣ x t + 1 ) (1.0) \begin{aligned} \hat q(x_t|x_{t+1},y)&=\frac{\hat q(x_t,x_{t+1},y)}{\hat q(x_{t+1},y)}\\ &=\frac{\hat q(x_t,x_{t+1},y)}{\hat q(y|x_{t+1})\hat q(x_{t+1})}\\ &=\frac{\hat q(x_t,y|x_{t+1})}{\hat q(y|x_{t+1})}\\ &=\frac{\hat q(y|x_t,x_{t+1})\hat q(x_{t}|x_{t+1})}{\hat q(y|x_{t+1})} \end{aligned}\tag{1.0} q^(xtxt+1,y)=q^(xt+1,y)q^(xt,xt+1,y)=q^(yxt+1)q^(xt+1)q^(xt,xt+1,y)=q^(yxt+1)q^(xt,yxt+1)=q^(yxt+1)q^(yxt,xt+1)q^(xtxt+1)(1.0)

已知条件扩散模型的前向过程与扩散模型一致,则有

q ^ ( x 1 : T ∣ x 0 ) = q ( x 1 : T ∣ x 0 ) \hat q(x_{1:T}|x_0)=q(x_{1:T}|x_0) q^(x1:Tx0)=q(x1:Tx0)

进而有
q ^ ( x t ) = ∫ q ^ ( x 0 , . . . , x t ) d x 0 : t − 1 = ∫ q ^ ( x 0 ) q ^ ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = ∫ q ( x 0 ) q ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = q ( x t ) \begin{aligned} \hat q(x_{t})&=\int \hat q(x_0,...,x_t) dx_{0:t-1}\\ &=\int \hat q(x_0)\hat q(x_{1:t}|x_0)dx_{0:t-1}\\ &=\int q(x_0)q(x_{1:t}|x_0)dx_{0:t-1}\\ &=q(x_t) \end{aligned} q^(xt)=q^(x0,...,xt)dx0:t1=q^(x0)q^(x1:tx0)dx0:t1=q(x0)q(x1:tx0)dx0:t1=q(xt)

对于 q ^ ( x t ∣ x t + 1 ) \hat q(x_t|x_{t+1}) q^(xtxt+1),则有
q ^ ( x t ∣ x t + 1 ) = q ^ ( x t , x t + 1 ) q ^ ( x t + 1 ) = q ^ ( x t + 1 ∣ x t ) q ^ ( x t ) q ^ ( x t + 1 ) = q ( x t + 1 ∣ x t ) q ( x t ) q ( x t + 1 ) = q ( x t ∣ x t + 1 ) \begin{aligned} \hat q(x_t|x_{t+1})&=\frac{\hat q(x_t,x_{t+1})}{\hat q(x_{t+1})}\\ &=\frac{\hat q(x_{t+1}|x_t)\hat q(x_{t})}{\hat q(x_{t+1})}\\ &=\frac{q(x_{t+1}|x_t)q(x_{t})}{q(x_{t+1})}\\ &=q(x_{t}|x_{t+1}) \end{aligned} q^(xtxt+1)=q^(xt+1)q^(xt,xt+1)=q^(xt+1)q^(xt+1xt)q^(xt)=q(xt+1)q(xt+1xt)q(xt)=q(xtxt+1)

对于 q ^ ( y ∣ x t , x x t + 1 ) \hat q(y|x_t,x_{x_{t+1}}) q^(yxt,xxt+1),我们有
q ^ ( y ∣ x t , x x t + 1 ) = q ^ ( x t + 1 ∣ x t , y ) q ^ ( y ∣ x t ) q ^ ( x t + 1 ∣ x t ) = q ^ ( x t + 1 ∣ x t ) q ^ ( y ∣ x t ) q ^ ( x t + 1 ∣ x t ) = q ^ ( y ∣ x t ) \begin{aligned} \hat q(y|x_t,x_{x_{t+1}})&=\frac{\hat q(x_{t+1}|x_t,y)\hat q(y|x_t)}{\hat q(x_{t+1}|x_t)}\\ &=\frac{\hat q(x_{t+1}|x_t)\hat q(y|x_t)}{\hat q(x_{t+1}|x_t)}\\ &=\hat q(y|x_t) \end{aligned} q^(yxt,xxt+1)=q^(xt+1xt)q^(xt+1xt,y)q^(yxt)=q^(xt+1xt)q^(xt+1xt)q^(yxt)=q^(yxt)

因此式1.0为

q ^ ( x t ∣ x t + 1 , y ) = q ^ ( y ∣ x t , x t + 1 ) q ^ ( x t ∣ x t + 1 ) q ^ ( y ∣ x t + 1 ) = q ^ ( y ∣ x t ) q ( x t ∣ x t + 1 ) q ^ ( y ∣ x t + 1 ) \begin{aligned} \hat q(x_t|x_{t+1},y)&=\frac{\hat q(y|x_t,x_{t+1})\hat q(x_{t}|x_{t+1})}{\hat q(y|x_{t+1})}\\ &=\frac{\hat q(y|x_t)q(x_{t}|x_{t+1})}{\hat q(y|x_{t+1})} \end{aligned} q^(xtxt+1,y)=q^(yxt+1)q^(yxt,xt+1)q^(xtxt+1)=q^(yxt+1)q^(yxt)q(xtxt+1)

由于在反向过程中, x t + 1 x_{t+1} xt+1是已知的,因此 q ^ ( y ∣ x t + 1 ) \hat q(y|x_{t+1}) q^(yxt+1)也可看成已知值,设其倒数为 Z Z Z,则有

q ^ ( x t ∣ x t + 1 , y ) = Z q ^ ( y ∣ x t ) q ( x t ∣ x t + 1 ) \begin{aligned} \hat q(x_t|x_{t+1},y) = Z\hat q(y|x_t)q(x_{t}|x_{t+1}) \end{aligned} q^(xtxt+1,y)=Zq^(yxt)q(xtxt+1)

取log可得
log ⁡ q ^ ( x t ∣ x t + 1 , y ) = log ⁡ Z + log ⁡ q ^ ( y ∣ x t ) + log ⁡ q ^ ( x t ∣ x t + 1 ) (1.1) \begin{aligned} \log \hat q(x_{t}|x_{t+1},y)=\log Z+\log \hat q(y|x_t)+\log \hat q(x_t|x_{t+1})\tag{1.1} \end{aligned} logq^(xtxt+1,y)=logZ+logq^(yxt)+logq^(xtxt+1)(1.1)

q ^ ( x t ∣ x t + 1 ) = N ( μ t , ∑ t 2 ) \hat q(x_t|x_{t+1})=\mathcal N(\mu_t,\sum_t^2) q^(xtxt+1)=N(μt,t2),则有
log ⁡ q ^ ( x t ∣ x t + 1 ) = − 1 2 ( x t − μ t ) T ( ∑ t ) − 1 ( x t − μ t ) + C (1.2) \log \hat q(x_{t}|x_{t+1})=-\frac{1}{2}(x_t-\mu_t)^T({\sum}_t)^{-1}(x_t-\mu_t)+C\tag{1.2} logq^(xtxt+1)=21(xtμt)T(t)1(xtμt)+C(1.2)

对于 log ⁡ q ^ ( y ∣ x t ) \log \hat q(y|x_t) logq^(yxt),在 x t = μ t x_t=\mu_t xt=μt处做泰勒展开,则有

log ⁡ q ^ ( y ∣ x t ) ≈ log ⁡ q ^ ( y ∣ x t ) ∣ x t = μ t + ( x t − μ t ) ∇ x t log ⁡ q ^ ( y ∣ x t ) ∣ x t = μ t = C 1 + ( x t − μ t ) g (1.3) \begin{aligned} \log \hat q(y|x_t) &\approx \log \hat q(y|x_t)|_{x_t=\mu_t}+(x_t-\mu_t)\nabla_{x_t}\log\hat q(y|x_t)|_{x_t=\mu_t}\\ &=C_1+(x_t-\mu_t)g \end{aligned}\tag{1.3} logq^(yxt)logq^(yxt)xt=μt+(xtμt)xtlogq^(yxt)xt=μt=C1+(xtμt)g(1.3)
其中 g = ∇ x t log ⁡ q ^ ( y ∣ x t ) ∣ x t = μ t g=\nabla_{x_t}\log\hat q(y|x_t)|_{x_t=\mu_t} g=xtlogq^(yxt)xt=μt,结合式1.1、1.2、1.3,有

log ⁡ q ^ ( x t ∣ x t + 1 , y ) ≈ C 1 + ( x t − μ t ) g + log ⁡ Z − 1 2 ( x t − μ t ) T ( ∑ t ) − 1 ( x t − μ t ) + C = ( x t − μ t ) g − 1 2 ( x t − μ t ) T ( ∑ t ) − 1 ( x t − μ t ) + C 2 = − 1 2 ( x t − μ t − ∑ t g ) T ( ∑ t ) − 1 ( x t − μ t − ∑ t g ) + C 3 \begin{aligned} \log \hat q(x_{t}|x_{t+1},y)&\approx C_1+(x_t-\mu_t)g+\log Z-\frac{1}{2}(x_t-\mu_t)^T(\sum{_t})^{-1}(x_t-\mu_t)+C\\ &=(x_t-\mu_t)g-\frac{1}{2}(x_t-\mu_t)^T(\sum{_t})^{-1}(x_t-\mu_t)+C_2\\ &=-\frac{1}{2}(x_t-\mu_t-\sum{_t} g)^T(\sum{_t})^{-1}(x_t-\mu_t-\sum{_t}g)+C_3 \end{aligned} logq^(xtxt+1,y)C1+(xtμt)g+logZ21(xtμt)T(t)1(xtμt)+C=(xtμt)g21(xtμt)T(t)1(xtμt)+C2=21(xtμttg)T(t)1(xtμttg)+C3

最终有

q ^ ( x t ∣ x t + 1 , y ) ≈ N ( μ t + ∑ t g , ( ∑ t ) 2 ) g = ∇ x t log ⁡ q ^ ( y ∣ x t ) ∣ x t = μ t (1.4) \begin{aligned} \hat q(x_t|x_{t+1},y)\approx \mathcal N(\mu_t+{\sum}_{t}g,({\sum}_t)^2)\\ g=\nabla_{x_t}\log\hat q(y|x_t)|_{x_t=\mu_t} \end{aligned}\tag{1.4} q^(xtxt+1,y)N(μt+tg,(t)2)g=xtlogq^(yxt)xt=μt(1.4)

为了获得 ∇ x t log ⁡ q ^ ( y ∣ x t ) \nabla_{x_t}\log\hat q(y|x_t) xtlogq^(yxt),Classifier Guidance Diffusion在训练好的Diffusion model的基础上额外训练了一个分类头。

假设 x t ≈ μ t x_t \approx\mu_t xtμt,则Classifier Guidance Diffusion的反向过程为:
深度学习(生成式模型)——Classifier Guidance Diffusion_第1张图片

其中 p ϕ ( y ∣ x t ) = q ^ ( y ∣ x t ) p_ \phi(y|x_t)=\hat q(y|x_t) pϕ(yxt)=q^(yxt) s s s为一个超参数。

式1.4有个问题,当方差 ∑ \sum 取值为0时, ∑ ∇ x t log ⁡ q ^ ( y ∣ x t ) {\sum}\nabla_{x_t}\log\hat q(y|x_t) xtlogq^(yxt)取值将为0,无法控制生成指定条件的图像。因此式1.4不适用于DDIM等确定性采样的扩散模型

在推导DDIM的采样公式前,我们先了解一下用Tweedie方法做参数估计的流程。

Tweedie方法主要用于指数族概率分布的参数估计,而高斯分布属于指数族概率分布,自然也适用。假设有一批样本 z z z,则利用样本 z z z估计高斯分布 N ( Z ; μ , ∑ 2 ) \mathcal N(Z;\mu,{\sum}^2) N(Z;μ,2)的均值 μ \mu μ的公式为

E [ μ ∣ z ] = z + ∑ 2 ∇ z log ⁡ p ( z ) (1.5) E[\mu|z]=z+{\sum}^2\nabla_z\log p(z)\tag{1.5} E[μz]=z+2zlogp(z)(1.5)

已知DDPM、DDIM的前向过程有

q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) (1.6) q(x_t|x_0)=\mathcal N(x_t;\sqrt{\bar \alpha_t}x_0,(1-\bar\alpha_t)\mathcal I)\tag{1.6} q(xtx0)=N(xt;αˉt x0,(1αˉt)I)(1.6)

结合式1.5、1.6可得

α ˉ t x 0 = x t + ( 1 − α ˉ t ) ∇ x t log ⁡ p ( x t ) \begin{aligned} \sqrt{\bar \alpha_t}x_0=x_t+(1-\bar\alpha_t)\nabla_{x_t}\log p(x_t) \end{aligned} αˉt x0=xt+(1αˉt)xtlogp(xt)
进而有
x t = α ˉ t x 0 − ( 1 − α ˉ t ) ∇ x t log ⁡ p ( x t ) (1.7) x_t=\sqrt{\bar \alpha_t}x_0-(1-\bar\alpha_t)\nabla_{x_t}\log p(x_t)\tag{1.7} xt=αˉt x0(1αˉt)xtlogp(xt)(1.7)
ϵ t \epsilon_t ϵt服从标准正态分布,则从式1.6可知

x t = α ˉ t x 0 + 1 − α ˉ t ϵ t (1.8) x_t=\sqrt{\bar \alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon_t\tag{1.8} xt=αˉt x0+1αˉt ϵt(1.8)

结合式1.7、1.8,则有

∇ x t log ⁡ p ( x t ) = − 1 1 − α ˉ t ϵ t (1.9) \nabla_{x_t}\log p(x_t)=-\frac{1}{\sqrt{1-\bar\alpha_t}}\epsilon_t\tag{1.9} xtlogp(xt)=1αˉt 1ϵt(1.9)

已知DDIM的采样公式为

x t − 1 = α ˉ t − 1 x t − 1 − α ˉ t ϵ θ ( x t ) α ˉ t + 1 − α ˉ t − δ t 2 ϵ θ ( x t ) (2.0) x_{t-1}=\sqrt{\bar \alpha_{t-1}}\frac{x_t-\sqrt{1-\bar \alpha_t}\epsilon_\theta(x_t)}{\sqrt{\bar\alpha_t}}+\sqrt{1-\bar\alpha_{t}-\delta_t^2}\epsilon_\theta(x_t)\tag{2.0} xt1=αˉt1 αˉt xt1αˉt ϵθ(xt)+1αˉtδt2 ϵθ(xt)(2.0)

结合式1.9、2.0可将DDIM的采样公式转变为

x t − 1 = α ˉ t − 1 x t − 1 − α ˉ t ( − 1 − α ˉ t ∇ x t log ⁡ p ( x t ) ) α ˉ t + 1 − α ˉ t − δ t 2 ( − 1 − α ˉ t ∇ x t log ⁡ p ( x t ) ) (2.1) x_{t-1}=\sqrt{\bar \alpha_{t-1}}\frac{x_t-\sqrt{1-\bar \alpha_t}(-\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p(x_t))}{\sqrt{\bar\alpha_t}}+\sqrt{1-\bar\alpha_{t}-\delta_t^2}(-\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p(x_t))\tag{2.1} xt1=αˉt1 αˉt xt1αˉt (1αˉt xtlogp(xt))+1αˉtδt2 (1αˉt xtlogp(xt))(2.1)

我们只需要将其中的 ∇ x t log ⁡ p ( x t ) \nabla_{x_t}\log p(x_t) xtlogp(xt)替换为 ∇ x t log ⁡ p ( x t ∣ y ) \nabla_{x_t}\log p(x_t|y) xtlogp(xty),即可引入条件 y y y来控制DDIM的生成过程,利用贝叶斯定理,我们有

log ⁡ p ( x t ∣ y ) = log ⁡ p ( y ∣ x t ) + log ⁡ p ( x t ) − log ⁡ p ( y ) ∇ x t log ⁡ p ( x t ∣ y ) = ∇ x t log ⁡ p ( y ∣ x t ) + ∇ x t log ⁡ p ( x t ) − ∇ x t log ⁡ p ( y ) = ∇ x t log ⁡ p ( y ∣ x t ) + ∇ x t log ⁡ p ( x t ) = ∇ x t log ⁡ p ( y ∣ x t ) − 1 1 − α ˉ t ϵ t (2.2) \begin{aligned} \log p(x_t|y)&=\log p(y|x_t)+\log p(x_t)-\log p(y)\\ \nabla_{x_t}\log p(x_t|y)&=\nabla_{x_t}\log p(y|x_t)+\nabla_{x_t}\log p(x_t)-\nabla_{x_t}\log p(y)\\ &=\nabla_{x_t}\log p(y|x_t)+\nabla_{x_t}\log p(x_t)\\ &=\nabla_{x_t}\log p(y|x_t)-\frac{1}{\sqrt{1-\bar\alpha_t}}\epsilon_t \end{aligned}\tag{2.2} logp(xty)xtlogp(xty)=logp(yxt)+logp(xt)logp(y)=xtlogp(yxt)+xtlogp(xt)xtlogp(y)=xtlogp(yxt)+xtlogp(xt)=xtlogp(yxt)1αˉt 1ϵt(2.2)
则有

− 1 − α ˉ t ∇ x t log ⁡ p ( x t ∣ y ) = ϵ t − 1 − α ˉ t ∇ x t log ⁡ p ( y ∣ x t ) (2.3) -\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p(x_t|y)=\epsilon_t-\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p(y|x_t)\tag{2.3} 1αˉt xtlogp(xty)=ϵt1αˉt xtlogp(yxt)(2.3)

至此,我们可以得到DDIM的采样流程为
在这里插入图片描述
对于DDIM等确定性采样的扩散模型,其应在训练好的Diffusion model的基础上额外训练了一个分类头,从而转变为Classifier Guidance Diffusion。

条件扩散模型的训练目标

注意到 q ^ ( x t ∣ x t + 1 ) = q ( x t ∣ x t + 1 ) \hat q(x_t|x_{t+1})=q(x_t|x_{t+1}) q^(xtxt+1)=q(xtxt+1),并且上述的推导过程并没有改变 q ( x t ∣ x t + 1 ) 、 q ( x t + 1 ∣ x t ) q(x_t|x_{t+1})、q(x_{t+1}|x_t) q(xtxt+1)q(xt+1xt)的形式,因此Classifier Guidance Diffusion的训练目标与DDPM、DDIM是一致的,都可以拟合训练数据。

你可能感兴趣的:(深度学习,深度学习,人工智能,AIGC,aigc)