diffusion model(三)—— classifier guided diffusion model

系列阅读

  • diffusion model(一)DDPM技术小结 (denoising diffusion probabilistic)
  • diffusion model(二)—— DDIM技术小结
  • diffusion model(三)—— classifier guided diffusion model
  • diffusion model(四)文生图diffusion model(classifier-free guided)
  • diffusion model(五)stable diffusion底层原理(latent diffusion model, LDM40779727/article/details/131405182?spm=1001.2014.3001.5501)

目录

    • 系列阅读
    • 背景
    • 方法大意
      • 基于条件的去噪过程
        • DDIM 中基于条件的去噪过程
      • 一些细节
        • classifier的训练
        • gradient score的作用
    • 参考文献
    • 附录

背景

对于一般的DM(如DDPM, DDIM)的采样过程是直接从一个噪声分布,通过不断采样来生成图片。但这个方法生成的图片类别是随机的,如何生成特定类别的图片呢?这就是classifier guide需要解决的问题。

方法大意

为了实现带类别标签 y y y的DM的推导,进行了以下定义
q ^ ( x 0 ) : = q ( x 0 ) q ^ ( y ∣ x 0 ) : = Know labels per sample q ^ ( x t + 1 ∣ x t , y ) : = q ( x t + 1 ∣ x t ) q ^ ( x 1 : T ∣ x 0 , y ) : = ∏ t = 1 T q ^ ( x t ∣ x t − 1 , y ) (1) \begin{aligned} \hat{q}(x_0) &:= q(x_0) \\ \hat{q}(y|x_0) &:= \text{Know labels per sample} \\ \hat{q}(x_{t+1}|x_{t}, y) &:= q(x_{t+1}|x_t) \\ \hat{q}(x_{1:T}|x_0, y)&:= \prod \limits_{t=1}^T\hat{q}(x_t|x_{t-1}, y) \\ \end{aligned} \tag{1} q^(x0)q^(yx0)q^(xt+1xt,y)q^(x1:Tx0,y):=q(x0):=Know labels per sample:=q(xt+1xt):=t=1Tq^(xtxt1,y)(1)
虽然上式定义了以 y y y为条件的噪声过程 q ^ \hat{q} q^,但我们还可以证明当 q ^ \hat{q} q^不以 y y y为条件时的行为与 q q q完全相同,即
q ^ ( x t + 1 ∣ x t ) = ∫ y q ^ ( x t + 1 , y ∣ x t ) d y = ∫ y q ^ ( x t + 1 ∣ x t , y ) q ^ ( y ∣ x t ) d y = ∫ y q ( x t + 1 ∣ x t ) q ^ ( y ∣ x t ) d y = q ( x t + 1 ∣ x t ) ∫ y q ^ ( y ∣ x t ) d y = q ( x t + 1 ∣ x t ) = q ^ ( x t + 1 ∣ x t , y ) (2) \begin{aligned} \hat{q}(x_{t+1}|x_t) &= \int_y \hat{q}(x_{t+1}, y| x_t)dy \\ &= \int_y \hat{q}(x_{t+1}|x_t, y)\hat{q}(y|x_t)dy \\ &= \int_y q(x_{t+1}|x_t)\hat{q}(y|x_t)dy \\ &= q(x_{t+1}|x_t) \int_y \hat{q}(y|x_t)dy \\ &= q(x_{t+1}|x_t) \\ &= \hat{q}(x_{t+1}|x_t, y) \\ \end{aligned}\tag{2} q^(xt+1xt)=yq^(xt+1,yxt)dy=yq^(xt+1xt,y)q^(yxt)dy=yq(xt+1xt)q^(yxt)dy=q(xt+1xt)yq^(yxt)dy=q(xt+1xt)=q^(xt+1xt,y)(2)
同样的思路:
q ^ ( x 1 : T ∣ x 0 ) = ∫ y q ^ ( x 1 : T , y ∣ x 0 ) d y = ∫ y q ^ ( x 1 : T ∣ y , x 0 ) q ( y ∣ x 0 ) d y = ∫ y ∏ t = 1 T q ^ ( x t ∣ x t − 1 , y ) ⏟ q ( x t ∣ x t − 1 ) q ( y ∣ x 0 ) d y = ∏ t = 1 T q ( x t ∣ x t − 1 ) ⏟ q ( x 1 : T ∣ x 0 ) ∫ y q ( y ∣ x 0 ) d y ⏟ = 1 = q ( x 1 : T ∣ x 0 ) (3) \begin{aligned} \hat{q}(x_{1:T}|x_0) &= \int_y \hat{q}(x_{1:T}, y|x_0) d_y \\ &= \int_y \hat{q}(x_{1:T}|y, x_0)q(y| x_0) d_y \\ &= \int_y \prod \limits_{t=1}^T \underbrace{ \hat{q}(x_t|x_{t-1}, y)}_{q(x_t|x_t-1)} q(y| x_0) d_y \\ &= \underbrace{\prod \limits_{t=1}^Tq(x_t|x_{t-1})}_{q(x_{1:T}|x_0)} \underbrace{\int_y q(y| x_0)d_y}_{=1} \\ &= q(x_{1:T}|x_0) \end{aligned}\tag{3} q^(x1:Tx0)=yq^(x1:T,yx0)dy=yq^(x1:Ty,x0)q(yx0)dy=yt=1Tq(xtxt1) q^(xtxt1,y)q(yx0)dy=q(x1:Tx0) t=1Tq(xtxt1)=1 yq(yx0)dy=q(x1:Tx0)(3)
根据上式同样可以推导出
q ^ ( x t ) = ∫ x 0 : t − 1 q ^ ( x 0 , ⋯   , x t ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ^ ( x 0 ) ⏟ q ( x 0 ) q ^ ( x 1 , ⋯   , x t ∣ x 0 ) ⏟ q ( x 1 : T ∣ x 0 ) d x 0 : t − 1 = q ( x t ) (4) \begin{aligned} \hat{q}(x_t) &= \int_{x_{0:t - 1}} \hat{q}(x_0, \cdots, x_t)dx_{0:t-1} \\ &= \int_{x_{0:t - 1}} \underbrace{\hat{q}(x_0)}_{q(x_0)} \underbrace{\hat{q}(x_1, \cdots, x_t|x_0)}_{q(x_{1:T}|x_0)}dx_{0:t-1} \\ &= q(x_t) \end{aligned} \tag{4} q^(xt)=x0:t1q^(x0,,xt)dx0:t1=x0:t1q(x0) q^(x0)q(x1:Tx0) q^(x1,,xtx0)dx0:t1=q(xt)(4)
由上述推导可见带条件的DM的前向过程与DDPM完全相同。并且根据贝叶斯公式,不带逆向过程也满足
p ^ ( x t ∣ x t + 1 ) = p ( x t ∣ x t + 1 ) (5) \hat{p}(x_t|x_{t+1}) = p(x_t|x_{t+1}) \tag{5} p^(xtxt+1)=p(xtxt+1)(5)
与此同时我们可以证明分类分布 q ^ ( y ∣ x t ) \hat{q}(y|x_t) q^(yxt)只和当前时刻的输入 x t x_t xt有关,与 x t + 1 x_{t+1} xt+1无关
q ^ ( y ∣ x t , x t + 1 ) = q ^ ( x t + 1 ∣ x t , y ) ⏞ q ^ ( x t + 1 ∣ x t ) q ^ ( y ∣ x t ) q ^ ( x t + 1 ∣ x t ) = q ^ ( y ∣ x t ) (6) \begin{aligned} \hat{q}(y|x_t, x_{t+1}) & = \frac{ \overbrace{ \hat{q}(x_{t+1}|x_t, y)}^{\hat{q}(x_{t+1}|x_t)} \hat{q}(y|x_t) } {\hat{q}(x_{t+1}|x_t )} \\ & = \hat{q}(y|x_t) \end{aligned} \tag{6} q^(yxt,xt+1)=q^(xt+1xt)q^(xt+1xt,y) q^(xt+1xt)q^(yxt)=q^(yxt)(6)

基于条件的去噪过程

将带类别信息的去噪过程定义为 p ^ ( x t ∣ x t + 1 , y ) \hat{p}(x_t|x_{t+1}, y) p^(xtxt+1,y)

p ^ ( x t ∣ x t + 1 , y ) = p ^ ( x t , x t + 1 , y ) p ^ ( y ∣ x t + 1 ) p ^ ( x t + 1 ) = p ^ ( x t , y ∣ x t + 1 ) p ^ ( y ∣ x t + 1 ) = p ^ ( y ∣ x t , x t + 1 ) ⏞ p ^ ( y ∣ x t ) p ^ ( x t ∣ x t + 1 ) ⏞ p ( x t ∣ x t + 1 ) p ^ ( y ∣ x t + 1 ) = p ^ ( y ∣ x t ) p ( x t ∣ x t + 1 ) p ^ ( y ∣ x t + 1 ) (7) \begin{aligned} \hat{p} (x_t| x_{t+1}, y) & = \frac{\hat{p} (x_t, x_{t+1}, y) }{\hat{p} (y|x_{t+1}) \hat{p} (x_{t+1}) } \\ & = \frac{\hat{p} (x_t, y | x_{t+1}) }{\hat{p} (y|x_{t+1}) } \\ & = \frac{\overbrace{\hat{p} (y|x_t, x_{t+1})}^{\hat{p}(y|x_t)} \overbrace{\hat{p}(x_t | x_{t+1})}^{p(x_t|x_{t+1})} }{\hat{p} (y|x_{t+1}) } \\ & = \frac{\hat{p} (y|x_t) p(x_t | x_{t+1}) }{\hat{p} (y|x_{t+1}) } \end{aligned} \tag{7} p^(xtxt+1,y)=p^(yxt+1)p^(xt+1)p^(xt,xt+1,y)=p^(yxt+1)p^(xt,yxt+1)=p^(yxt+1)p^(yxt,xt+1) p^(yxt)p^(xtxt+1) p(xtxt+1)=p^(yxt+1)p^(yxt)p(xtxt+1)(7)
由于 x t + 1 x_{t+1} xt+1是已知的, p ^ ( y ∣ x t + 1 ) \hat{p} (y|x_{t+1}) p^(yxt+1)这个概率分布与 x t x_t xt无关,可以将 p ^ ( y ∣ x t + 1 ) \hat{p} (y|x_{t+1}) p^(yxt+1)视为常数 Z Z Z。此时上式可以表述为
p ^ ( x t ∣ x t + 1 , y ) = Z p ^ ( y ∣ x t ) p ( x t ∣ x t + 1 ) (8) \hat{p} (x_t| x_{t+1}, y) = Z \hat{p} (y|x_t) p(x_t | x_{t+1}) \tag{8} p^(xtxt+1,y)=Zp^(yxt)p(xtxt+1)(8)
上式的右边第二项 p ^ ( y ∣ x t ) \hat{p} (y|x_t) p^(yxt)很容易得到,我们可以根据 x t , y x_t, y xt,y的pair对训练一个分类模型 p ^ ϕ ( y ∣ x t ) \hat{p}_\phi(y|x_t) p^ϕ(yxt)

上式的右边第三项 p ( x t ∣ x t + 1 ) p(x_t | x_{t+1}) p(xtxt+1)在DDPM中也能够通过一个neural network进行估计 p ( x t ∣ x t + 1 ) ≈ p θ ( x t ∣ x t + 1 ) p(x_t | x_{t+1}) \approx p_\theta(x_t|x_{t+1}) p(xtxt+1)pθ(xtxt+1)

故采样分布
p ^ ( x t ∣ x t + 1 , y ) ≈ p ^ ϕ , θ ( x t ∣ x t + 1 , y ) = Z p ^ ϕ ( y ∣ x t ) p θ ( x t ∣ x t + 1 ) (9) \begin{aligned} \hat{p} (x_t| x_{t+1}, y) &\approx \hat{p}_{\phi, \theta} (x_t| x_{t+1}, y) \\ &= Z \hat{p}_{\phi} (y|x_t) p_{\theta}(x_t | x_{t+1}) \end{aligned} \tag{9} p^(xtxt+1,y)p^ϕ,θ(xtxt+1,y)=Zp^ϕ(yxt)pθ(xtxt+1)(9)
下面来看有了上面这个式子如何进行采样

直接对上面的式子进行采样是很难解决的。论文参考文献1将上式近似为perturbed Gaussian distribution。

根据前文DM的推导可知 p θ ( x t ∣ x t + 1 ) = N ( μ , Σ ) = 1 2 π Σ exp ⁡ ( − ( x − μ ) 2 2 Σ ) p_{\theta}(x_t | x_{t+1}) = \mathcal{N}(\mu, \Sigma)=\frac{1}{\sqrt{2\pi} \sqrt{\Sigma} } \exp \left ({- \frac{(x - \mu)^2}{2\Sigma}} \right) pθ(xtxt+1)=N(μ,Σ)=2π Σ 1exp((xμ)2) ,对其取对数
log ⁡ p θ ( x t ∣ x t + 1 ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + C (10) \log p_{\theta}(x_t|x_{t+1}) = - \frac{1}{2} (x_t - \mu)^T \Sigma^{-1} (x_t - \mu) + C \tag{10} logpθ(xtxt+1)=21(xtμ)TΣ1(xtμ)+C(10)
对于 log ⁡ p ^ ϕ ( y ∣ x t ) \log \hat{p}_{\phi} (y|x_t) logp^ϕ(yxt) 作者假设其curvature比 Σ − 1 \Sigma^{-1} Σ1低。这个假设是合理的,对于当diffusion steps足够大时, ∥ Σ ∥ → 0 \parallel \Sigma \parallel \rightarrow 0 Σ∥→0。在该情况下,对 log ⁡ p ^ ϕ ( y ∣ x t ) \log\hat{p}_{\phi} (y|x_t) logp^ϕ(yxt) x t = μ x_t = \mu xt=μ处进行泰勒展开
log ⁡ p ^ ϕ ( y ∣ x t ) ≈ log ⁡ p ^ ϕ ( y ∣ x t ) ∣ x t = μ + ( x t − μ ) ∇ x t log ⁡ p ϕ ( y ∣ x t ) ∣ x t = μ = ( x t − μ ) g + C 1 where:  g = ∇ x t log ⁡ p ϕ ( y ∣ x t ) ∣ x t = μ , C 1  is a contant. (11) \begin{aligned} \log \hat{p}_{\phi} (y|x_t) & \approx \log \hat{p}_{\phi} (y|x_t) | _{x_t = \mu} + (x_t - \mu) \nabla_{x_t} \log p_{\phi} (y|x_t)|_{x_t = \mu} \\ &= (x_t - \mu) g + C_1 \\ \text{where: } g &= \nabla_{x_t} \log p_{\phi} (y|x_t)|_{x_t = \mu}, C_1\text{ is a contant.} \end{aligned} \tag{11} logp^ϕ(yxt)where: glogp^ϕ(yxt)xt=μ+(xtμ)xtlogpϕ(yxt)xt=μ=(xtμ)g+C1=xtlogpϕ(yxt)xt=μ,C1 is a contant.(11)

log ⁡ ( p ^ ϕ ( y ∣ x t ) p θ ( x t ∣ x t + 1 ) ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) g + C 2 = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + C 3 = log ⁡ p ( z ) + C 4 , z ∼ N ( μ + Σ g , Σ ) (12) \begin{aligned} \log (\hat{p}_{\phi} (y|x_t) p_{\theta}(x_t | x_{t+1})) & = - \frac{1}{2} (x_t - \mu)^T \Sigma^{-1} (x_t - \mu) + (x_t - \mu) g + C_2 \\ & = - \frac{1}{2} (x_t - \mu - \Sigma g)^T \Sigma^{-1} (x_t - \mu- \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ & = - \frac{1}{2} (x_t - \mu - \Sigma g)^T \Sigma^{-1} (x_t - \mu- \Sigma g) + C_3 \\ & = \log p(z) + C_4, z \sim \mathcal{N}(\mu + \Sigma g, \Sigma) \end{aligned} \tag{12} log(p^ϕ(yxt)pθ(xtxt+1))=21(xtμ)TΣ1(xtμ)+(xtμ)g+C2=21(xtμΣg)TΣ1(xtμΣg)+21gTΣg+C2=21(xtμΣg)TΣ1(xtμΣg)+C3=logp(z)+C4,zN(μ+Σg,Σ)(12)

(附录给出了验证性证明)

通过上述推导,我们得到了带类别条件的采样过程也可以用高斯分布来近似,只是均值需要加上 Σ g \Sigma g Σg。具体的算法如下
diffusion model(三)—— classifier guided diffusion model_第1张图片

代码实现

p_mean_var_ddpm是DDPM对高斯分布均值、方差的计算函数

p_mean_var_ddpm_with_classifier是引入类别控制后的对高斯分布均值、方差的计算函数

有了均值方差就可以进行采样了

def p_mean_var_ddpm(self, noise_model, x, t):
    """
    Math:
    \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} x_t -
        \frac{1 - \alpha_t }{\sqrt{\alpha_t}\sqrt{1 - \overline{\alpha}_t}}f_\theta(x_t, t) \tag{30}
    """
    betas_t = extract(self.betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        self.sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)
    model_mean_t = sqrt_recip_alphas_t * (
        x - betas_t * noise_model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = extract(self.posterior_variance, t, x.shape)
    return model_mean_t, posterior_variance_t

  
def p_mean_var_ddpm_with_classifier(classifier, noise_model, x, t, y=None, cfs=1):
    def cond_fn(x: torch.Tensor, t: torch.Tensor, y: torch.Tensor): 
        assert y is not None
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            return torch.autograd.grad(selected.sum(), x_in)[0].float()   # gradient descend
    grad = cond_fn(x_temp, t, y=y) * cfs 
    model_mean_t, posterior_variance_t = p_mean_var_ddpm(noise_model, x, t)
    new_mean = model_mean_t + posterior_variance_t * grad
    return new_mean, posterior_variance_t
DDIM 中基于条件的去噪过程

上述条件抽样推导仅对随机扩散采样过程有效,不能应用于DDIM2等确定性采样方法(因为DDIM中设定了方差为0,故无法推导出式19)。为此,作者在研究中采用score-based的思路,参考了Song等人[^ 3]的方法,并利用了扩散模型和score matching之间的联系3

首先根据贝叶斯公式
p ( x t ∣ y ) = p ( y ∣ x t ) p ( x t ) p ( y ) ⇒ log ⁡ p ( x t ∣ y ) = log ⁡ p ( y ∣ x t ) + log ⁡ p ( x t ) − log ⁡ p ( y ) ⇒ 对 x t 求导 ∇ x t log ⁡ p ( x t ∣ y ) = ∇ x t log ⁡ p ( y ∣ x t ) + ∇ x t log ⁡ p ( x t ) − ∇ x t log ⁡ p ( y ) ⏟ = 0 ⇒ ∇ x t log ⁡ p ( x t ∣ y ) = ∇ x t log ⁡ p ( y ∣ x t ) + ∇ x t log ⁡ p ( x t ) (13) \begin{aligned} p (x_t| y) & = \frac{p (y|x_t) p(x_t) }{p (y) } \\ \Rightarrow \log{p (x_t| y) } &= \log{p (y|x_t)} + \log{p(x_t)} - \log{p (y) } \\ \stackrel{对x_t求导} \Rightarrow \nabla_{x_t}\log{p (x_t|y)} &= \nabla_{x_t}\log{p (y|x_t)} + \nabla_{x_t}\log{p(x_t)} - \underbrace{\nabla_{x_t}\log{p(y) }}_{=0} \\ \Rightarrow \nabla_{x_t}\log{p(x_t| y)} &= \nabla_{x_t}\log{p(y|x_t)} + \nabla_{x_t}\log{p(x_t)} \\ \end{aligned} \tag{13} p(xty)logp(xty)xt求导xtlogp(xty)xtlogp(xty)=p(y)p(yxt)p(xt)=logp(yxt)+logp(xt)logp(y)=xtlogp(yxt)+xtlogp(xt)=0 xtlogp(y)=xtlogp(yxt)+xtlogp(xt)(13)
具体来说,如果我们有一个模型 ϵ θ ( x t ) \epsilon_\theta(x_t) ϵθ(xt)来预测添加到样本中的噪声,那么可以利用它来推导出一个score function:
∇ x t log ⁡ p θ ( x t ) = − 1 1 − α ‾ t ϵ θ ( x t ) (14) \nabla_{x_t} \log p_\theta (x_t) = - \frac{1}{\sqrt{1 - \overline{\alpha}_t}} \epsilon_\theta(x_t) \tag{14} xtlogpθ(xt)=1αt 1ϵθ(xt)(14)
代入式(20)得
∇ x t log ⁡ p ( x t ∣ y ) = ∇ x t log ⁡ p ( y ∣ x t ) − 1 1 − α ‾ t ϵ θ ( x t ) ⇒ 1 − α ‾ t ∇ x t log ⁡ p ( x t ∣ y ) = 1 − α ‾ t ∇ x t log ⁡ p ( y ∣ x t ) − ϵ θ ( x t ) (15) \begin{aligned} \nabla_{x_t}\log{p(x_t| y)} &= \nabla_{x_t}\log{p(y|x_t)} - \frac{1}{\sqrt{1 - \overline{\alpha}_t}} \epsilon_\theta(x_t) \\ \Rightarrow \sqrt{1 - \overline{\alpha}_t} \nabla_{x_t}\log{p(x_t| y)} &= \sqrt{1 - \overline{\alpha}_t} \nabla_{x_t}\log{p(y|x_t)} - \epsilon_\theta(x_t) \end{aligned} \tag{15} xtlogp(xty)1αt xtlogp(xty)=xtlogp(yxt)1αt 1ϵθ(xt)=1αt xtlogp(yxt)ϵθ(xt)(15)
定义在条件 y y y下的估计噪声 ϵ ^ ( x t ∣ y ) \hat{\epsilon}(x_t|y) ϵ^(xty)为:
ϵ ^ ( x t ∣ y ) : = ϵ θ ( x t ) − 1 − α ‾ t ∇ x t log ⁡ p ϕ ( y ∣ x t ) (16) \hat{\epsilon}(x_t|y) := \epsilon_\theta(x_t) - \sqrt{1 - \overline{\alpha}_t}\nabla_{x_t} \log{p_\phi(y|x_t)} \tag{16} ϵ^(xty):=ϵθ(xt)1αt xtlogpϕ(yxt)(16)
只需将DDIM中的$ \epsilon_\theta(x_t) 替换为 替换为 替换为\hat{\epsilon}(x_t|y)$就得到了基于条件的去噪过程。

diffusion model(三)—— classifier guided diffusion model_第2张图片

代码上也很直观

def p_sample_ddim(self, model, x, t):
    """
    x_{t-1} &=  \sqrt{\overline{\alpha}_{t-1}} \frac{x_t - \sqrt{1 - \overline{\alpha}_{t}}\boldsymbol{\epsilon}_\theta(x_t, t)}
        {\sqrt{\overline{\alpha}_{t}}} +  \sqrt{1 - \overline{\alpha}_{t-1} } \boldsymbol{\epsilon}_\theta(x_t, t)
    """
    sqrt_alphas_cumprod_prev_t = extract(self.sqrt_alphas_cumprod_prev, t, x.shape) 
    sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_one_minus_alphas_cumprod_prev_t = extract(self.sqrt_one_minus_alphas_cumprod_prev, t, x.shape) 
    sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x.shape) 
    pred_noise = model(x, t)
    pred_x0 = sqrt_alphas_cumprod_prev_t * (x - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t
    x0_direction = sqrt_one_minus_alphas_cumprod_prev_t * pred_noise 
    return pred_x0 + x0_direction
  
  
def p_sample_with_classifier(self, model, x, t, t_index, y=None, **kwargs):
    if y is None:
        return self.p_sample_ddim(model, x, t, t_index=t_index)
    cfs = kwargs.get("cfs", 1) 
    sqrt_alphas_cumprod_prev_t = extract(self.sqrt_alphas_cumprod_prev, t, x.shape) 
    sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_one_minus_alphas_cumprod_prev_t = extract(self.sqrt_one_minus_alphas_cumprod_prev, t, x.shape) 
    sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x.shape) 
    pred_noise = model(x, t)
    score = self.cond_fn(x, t, y=y) * cfs
    pred_noise = pred_noise - sqrt_one_minus_alphas_cumprod_t * score  # update noise 
    pred_x0 = sqrt_alphas_cumprod_prev_t * (x - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t
    x0_direction = sqrt_one_minus_alphas_cumprod_prev_t * pred_noise 
    return pred_x0 + x0_direction

一些细节

classifier的训练

classifier的训练与扩散模型的训练可以是独立的。在训练classifier的时候可以噪声预测模型(Unet)的encode部分作为主干,在后面接了一个分类层。并且需要与相应的扩散模型相同的噪声分布对classifier进行训练。训练数据集如 [ ( x 1 t , t , y 1 ) , ( x 2 t , t , y 2 ) , . . . , ( x N t , t , y N ) ] [(x_1^t,t, y_1), (x_2^t,t, y_2), ..., (x_N^t,t, y_N)] [(x1t,t,y1),(x2t,t,y2),...,(xNt,t,yN)] t t t是对时间步的采样, x t x^t xt x x x在时间步 t t t的输出。训练完成后,采用上面的算法集成到采样过程中。

gradient score的作用

在上面的采样算法我们看到有一个gradient scale s s s来对梯度进行拉伸。

实验视角

一般来说当 s = 1 s=1 s=1时,大约能保证生成的图片50%是想要的类别4,随着 s s s的增大,这个比例也能够增加。如下图,当 s s s增加到10,此时生成的图片都是期望的类别。因此 s s s也称之为guidance scale。

其实理解这个scale还有另一个视角

s ∇ x t log ⁡ ( p ϕ ( y ∣ x t ) ) = ∇ x t log ⁡ ( p ϕ ( y ∣ x t ) s ) s\nabla_{x_t} \log (p_\phi(y|x_t)) = \nabla_{x_t} \log (p_\phi(y|x_t)^s) sxtlog(pϕ(yxt))=xtlog(pϕ(yxt)s),当 s > 1 s>1 s>1他相当于对分布 p ϕ ( y ∣ x t ) p_\phi(y|x_t) pϕ(yxt)进行了一个指数拉升,从而带来更大的梯度更新收益。

根据DM的采样过程,当没有classifier guided时,在时刻 t t t,的采样过程应当是
x t − 1 = μ θ ( x t , t ) + σ ( t ) ϵ , 其中 ϵ ∈ N ( ϵ ; 0 , I ) = 1 α t ( x t − 1 − α t 1 − α ‾ t ϵ θ ( x t , t ) ) ⏟ μ θ ( x t , t ) + σ ( t ) ϵ (17) \begin{aligned} x_{t-1} &= \mu_{\theta}(x_t, t) + \sigma(t) \epsilon,其中 \epsilon \in \mathcal{N}(\epsilon; 0, \textbf{I}) \\ & = \underbrace{\frac{1}{\sqrt{\alpha_t}} (x_t - \frac{1 - \alpha_t }{\sqrt{1 - \overline{\alpha}_t}}\epsilon_\theta(x_t, t))}_{\mu_\theta(x_t, t)} + \sigma(t) \epsilon \end{aligned} \tag{17} xt1=μθ(xt,t)+σ(t)ϵ,其中ϵN(ϵ;0,I)=μθ(xt,t) αt 1(xt1αt 1αtϵθ(xt,t))+σ(t)ϵ(17)
当加了classifier guided相当于将 μ θ ( x t , t ) \mu_{\theta}(x_t, t) μθ(xt,t)向预测类别为 y y y的方向更新了一小步。 s s s是控制更新的幅值。
x t − 1 = μ θ ( x t , t ) + s ∇ x t log ⁡ p ϕ ( y ∣ x t ) ∣ x t = μ θ ( x t , t ) + σ ( t ) ϵ , 其中 ϵ ∈ N ( ϵ ; 0 , I ) \begin{align} x_{t-1} &=& \mu_{\theta}(x_t, t) + s\nabla_{x_t} \log p_{\phi} (y|x_t)|_{x_t = \mu_{\theta}(x_t, t)} + \sigma(t) \epsilon,其中 \epsilon \in \mathcal{N}(\epsilon; 0, \textbf{I}) \tag{18} \end{align} xt1=μθ(xt,t)+sxtlogpϕ(yxt)xt=μθ(xt,t)+σ(t)ϵ,其中ϵN(ϵ;0,I)(18)

参考文献

附录

式12推导验证
− 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T − μ T − g T Σ T ) Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T − μ T − g T Σ T ) Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T Σ − 1 − μ T Σ − 1 − g T Σ T Σ − 1 ⏟ g T ) ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T Σ − 1 ( x t − μ − Σ g ) − μ T Σ − 1 ( x t − μ − Σ g ) − g T ( x t − μ − Σ g ) ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t T Σ − 1 ( x t − μ ) − μ T Σ − 1 ( x t − μ ) ) ⏟ ( x t − μ ) T Σ − 1 ( x t − μ ) − 1 2 ( − g T ( x t − μ − Σ g ) + ( − x t T Σ − 1 Σ g ) ⏟ − x t T g + μ T Σ − 1 Σ g ⏟ μ T g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) g + C 2 \begin{align*} &- \frac{1}{2} (x_t - \mu - \Sigma g)^T \Sigma^{-1} (x_t - \mu- \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ = &- \frac{1}{2} (x_t^T - \mu^T - g^T \Sigma^T) \Sigma^{-1} (x_t - \mu - \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ = &- \frac{1}{2} (x_t^T - \mu^T - g^T \Sigma^T) \Sigma^{-1} (x_t - \mu - \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ \\ = & - \frac{1}{2} (x_t^T \Sigma^{-1} - \mu^T \Sigma^{-1} - \underbrace{g^T \Sigma^T \Sigma^{-1}}_{g^T} )(x_t - \mu - \Sigma g) + \frac{1}{2}g^T\Sigma g + C_2 \\ = & - \frac{1}{2} (x_t^T \Sigma^{-1} (x_t - \mu - \Sigma g) - \mu^T \Sigma^{-1} (x_t - \mu - \Sigma g) - g^T (x_t - \mu - \Sigma g)) + \frac{1}{2}g^T\Sigma g + C_2 \\ = & - \frac{1}{2} \underbrace{(x_t^T \Sigma^{-1} (x_t - \mu ) - \mu^T \Sigma^{-1} (x_t - \mu))}_{(x_t - \mu)^T \Sigma^{-1} (x_t - \mu)} - \frac{1}{2} ( - g^T (x_t - \mu - \Sigma g) + \underbrace{(- x_t^T \Sigma^{-1}\Sigma g)}_{-x_t^Tg} + \underbrace{\mu^T \Sigma^{-1}\Sigma g}_{\mu^Tg}) + \frac{1}{2}g^T\Sigma g + C_2 \\ = & - \frac{1}{2} (x_t - \mu)^T \Sigma^{-1} (x_t - \mu) + (x_t - \mu) g + C_2 \\ \end{align*} ======21(xtμΣg)TΣ1(xtμΣg)+21gTΣg+C221(xtTμTgTΣT)Σ1(xtμΣg)+21gTΣg+C221(xtTμTgTΣT)Σ1(xtμΣg)+21gTΣg+C221(xtTΣ1μTΣ1gT gTΣTΣ1)(xtμΣg)+21gTΣg+C221(xtTΣ1(xtμΣg)μTΣ1(xtμΣg)gT(xtμΣg))+21gTΣg+C221(xtμ)TΣ1(xtμ) (xtTΣ1(xtμ)μTΣ1(xtμ))21(gT(xtμΣg)+xtTg (xtTΣ1Σg)+μTg μTΣ1Σg)+21gTΣg+C221(xtμ)TΣ1(xtμ)+(xtμ)g+C2


  1. Deep unsupervised learning using nonequilibrium thermodynamics ↩︎

  2. [Denoising Diffusion Implicit Models (DDIM) Sampling](https://arxiv.org/abs/2010.02502) ↩︎

  3. Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. arXiv:arXiv:1907.05600, 2020. ↩︎

  4. Diffusion Models Beat GANs on Image Synthesis ↩︎

你可能感兴趣的:(diffusion,model,diffusion,model,guide)