Guided Diffusion/Diffusion Models Beat GANs on Image Synthesis (Paper reading)

Guided Diffusion/Diffusion Models Beat GANs on Image Synthesis (Paper reading)

Prafulla Dhariwal, OpenAI, NeurlPS2021, Cited: 555, Code, Paper.

目录子

  • Guided Diffusion/Diffusion Models Beat GANs on Image Synthesis (Paper reading)
    • 1. 前言
    • 2. 整体思想
    • 3. 方法
    • 4. 总结

1. 前言

对于条件图像合成,我们通过分类器指导进一步提高样本质量:一种简单、计算效率高的方法,使用分类器的梯度来权衡样本质量的多样性。我们在 ImageNet 128×128 上实现了 2.97 的 FID,在 ImageNet 256×256 上实现了 4.59,在 ImageNet 512×512 上实现了 7.72。即使每个样本只有 25 次前向传播,我们也能匹配 BigGAN-deep,同时保持更好的分布覆盖。最后,我们发现分类器引导与上采样扩散模型结合得很好,在 ImageNet 512×512 上进一步将 FID 提高到 3.85。

2. 整体思想

在扩散模型的梯度上加上额外的条件梯度使扩散模型可以有条件生成。

    def cond_fn(x, t, y=None):
        assert y is not None
        with th.enable_grad():
        	# x_in 就是 x_t
            x_in = x.detach().requires_grad_(True)
            # 分类器对x_t的预测结果
            logits = classifier(x_in, t)
            # 将预测的结果转换成概率
            log_probs = F.log_softmax(logits, dim=-1)
            # 对于GT的概率
            selected = log_probs[range(len(logits)), y.view(-1)]
            # 将所有概率求和得到梯度修正
            return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

3. 方法

首先,定义一个有条件的扩散过程 q ^ \hat q q^,并假设 q ^ ( y ∣ x 0 ) \hat q(y|x_{0}) q^(yx0)是已知的:
q ^ ( x 0 ) : = q ( x 0 ) q ^ ( y ∣ x 0 ) : = Known labels per sample q ^ ( x t + 1 ∣ x t , y ) : = q ( x t + 1 ∣ x t ) q ^ ( x 1 : T ∣ x 0 , y ) : = ∏ T t = 1 q ^ ( x t ∣ x t − 1 , y ) \begin{align} \hat q(x_{0})&:=q(x_{0}) \tag{1} \\ \hat q(y|x_{0})&:= \text{Known labels per sample} \tag{2} \\ \hat q(x_{t+1}|x_{t},y) &:=q(x_{t+1}|x_{t})\tag{3}\\ \hat q(x_{1:T}|x_{0},y) &:= \prod_{T}^{t=1} \hat q(x_{t}|x_{t-1},y) \tag{4} \end{align} q^(x0)q^(yx0)q^(xt+1xt,y)q^(x1:Tx0,y):=q(x0):=Known labels per sample:=q(xt+1xt):=Tt=1q^(xtxt1,y)(1)(2)(3)(4)

(1)可知训练样本是不变的,(2)可知每一个样本都有已知标签,(3)可知加噪过程不变,(4)可知有条件的联合分布也是一个马尔可夫性质的。那么通过以上定义的假设,接下来分别证明

q ^ ( x t + 1 ∣ x t ) \hat q(x_{t+1}|x_{t}) q^(xt+1xt)的边缘分布可以表示为 ∫ y q ^ ( x t + 1 , y ∣ x t ) d y \int_{y} \hat q(x_{t+1}, y|x_{t})dy yq^(xt+1,yxt)dy,我们把 x t x_{t} xt当作一个条件时,可以将上式看作边缘概率与条件概率相乘(多变量贝叶斯公式) ∫ y q ^ ( x t + 1 ∣ x t , y ) q ^ ( y ∣ x t ) d y \int_{y} \hat q(x_{t+1}|x_{t}, y) \hat q(y|x_{t})dy yq^(xt+1xt,y)q^(yxt)dy,继续推导:
∫ y q ^ ( x t + 1 ∣ x t , y ) q ^ ( y ∣ x t ) d y = ∫ y q ^ ( x t + 1 ∣ x t ) q ^ ( y ∣ x t ) d y = q ^ ( x t + 1 ∣ x t ) ∫ y q ^ ( y ∣ x t ) d y , ∫ y q ^ ( y ∣ x t ) d y = 1 = q ^ ( x t + 1 ∣ x t ) = q ^ ( x t + 1 ∣ x t , y ) \begin{aligned} \int_{y} \hat q(x_{t+1}|x_{t}, y) \hat q(y|x_{t})dy&=\int_{y} \hat q(x_{t+1}|x_{t}) \hat q(y|x_{t})dy\\ &=\hat q(x_{t+1}|x_{t})\int_{y} \hat q(y|x_{t})dy, \quad \int_{y}\hat q(y|x_{t})dy=1\\ &=\hat q(x_{t+1}|x_{t})=\hat q(x_{t+1}|x_{t}, y) \end{aligned} yq^(xt+1xt,y)q^(yxt)dy=yq^(xt+1xt)q^(yxt)dy=q^(xt+1xt)yq^(yxt)dy,yq^(yxt)dy=1=q^(xt+1xt)=q^(xt+1xt,y)
相同的逻辑可以推断出联合分布 q ^ ( x 1 : T ∣ x 0 ) = q ( x 1 : T ∣ x 0 ) \hat q(x_{1:T}|x_{0})=q(x_{1:T}|x_{0}) q^(x1:Tx0)=q(x1:Tx0)。根据这个联合分布得:
q ^ ( x t ) = ∫ x 0 : t − 1 q ^ ( x 0 , . . . , x t ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ^ ( x 0 ) q ^ ( x 1 , . . . , x t ∣ x 0 ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ( x 0 ) q ( x 1 , . . . , x t ∣ x 0 ) d x 0 : t − 1 = ∫ x 0 : t − 1 q ( x 0 , . . . , x t ) d x 0 : t − 1 = q ( x t ) \begin{aligned} \hat q(x_{t})&=\int_{x_{0}:t-1} \hat q(x_{0},...,x_{t})d_{x_{0}:t-1}\\ &=\int_{x_{0}:t-1} \hat q(x_{0}) \hat q(x_{1},...,x_{t}|x_{0})d_{x_{0}:t-1}\\ &=\int_{x_{0}:t-1} q(x_{0})q(x_{1},...,x_{t}|x_{0})d_{x_{0}:t-1}\\ &=\int_{x_{0}:t-1} q(x_{0},...,x_{t})d_{x_{0}:t-1}\\ &=q(x_{t}) \end{aligned} q^(xt)=x0:t1q^(x0,...,xt)dx0:t1=x0:t1q^(x0)q^(x1,...,xtx0)dx0:t1=x0:t1q(x0)q(x1,...,xtx0)dx0:t1=x0:t1q(x0,...,xt)dx0:t1=q(xt)
由于 q ^ ( x t ) = q ( x t ) \hat q(x_{t})=q(x_{t}) q^(xt)=q(xt) q ^ ( x t + 1 ∣ x t ) = q ( x t + 1 ∣ x t ) \hat q(x_{t+1}|x_{t})= q(x_{t+1}|x_{t}) q^(xt+1xt)=q(xt+1xt),由贝叶斯公式可知任意时刻 x t x_{t} xt q ^ \hat q q^ q q q中的表现是一样的。那么在有条件下的逆过程 q ^ ( x t ∣ x t + 1 , y ) \hat q(x_{t}|x_{t+1},y) q^(xtxt+1,y)是什么样的呢?
q ^ ( x t ∣ x t + 1 , y ) = q ^ ( x t , x t + 1 , y ) q ^ ( x t + 1 , y ) = q ^ ( x t , x t + 1 , y ) q ^ ( y ∣ x t + 1 ) q ^ ( x t + 1 ) = q ^ ( x t ∣ x t + 1 ) q ^ ( y ∣ x t , x t + 1 ) q ^ ( x t + 1 ) q ^ ( y ∣ x t + 1 ) q ^ ( x t + 1 ) = q ^ ( x t ∣ x t + 1 ) q ^ ( y ∣ x t , x t + 1 ) q ^ ( y ∣ x t + 1 ) \begin{align} \hat q(x_{t}|x_{t+1},y)&=\frac{\hat q(x_{t},x_{t+1},y)}{\hat q(x_{t+1},y)}\tag{5}\\ &=\frac{\hat q(x_{t},x_{t+1},y)}{\hat q(y|x_{t+1})\hat q(x_{t+1})}\tag{6}\\ &=\frac{\hat q(x_{t}|x_{t+1}) \hat q(y|x_{t},x_{t+1}) \hat q(x_{t+1})}{\hat q(y|x_{t+1}) \hat q(x_{t+1})}\tag{7}\\ &=\frac{\hat q(x_{t}|x_{t+1}) \hat q(y|x_{t},x_{t+1})}{\hat q(y|x_{t+1})}\tag{8}\\ \end{align} q^(xtxt+1,y)=q^(xt+1,y)q^(xt,xt+1,y)=q^(yxt+1)q^(xt+1)q^(xt,xt+1,y)=q^(yxt+1)q^(xt+1)q^(xtxt+1)q^(yxt,xt+1)q^(xt+1)=q^(yxt+1)q^(xtxt+1)q^(yxt,xt+1)(5)(6)(7)(8)
公式8中的 q ^ ( y ∣ x t , x t + 1 ) \hat q(y|x_{t},x_{t+1}) q^(yxt,xt+1)
q ^ ( y ∣ x t , x t + 1 ) = q ^ ( x t + 1 ∣ x t , y ) q ^ ( y ∣ x t ) q ^ ( x t + 1 ∣ x t ) = q ^ ( x t + 1 ∣ x t ) q ^ ( y ∣ x t ) q ^ ( x t + 1 ∣ x t ) = q ^ ( y ∣ x t ) \begin{align} \hat q(y|x_{t},x_{t+1})&=\hat q(x_{t+1}|x_{t},y)\frac{\hat q(y|x_{t})}{\hat q(x_{t+1}|x_{t})}\tag{9}\\ &=\hat q(x_{t+1}|x_{t})\frac{\hat q(y|x_{t})}{\hat q(x_{t+1}|x_{t})}\tag{10}\\ &=\hat q(y|x_{t})\tag{11} \end{align} q^(yxt,xt+1)=q^(xt+1xt,y)q^(xt+1xt)q^(yxt)=q^(xt+1xt)q^(xt+1xt)q^(yxt)=q^(yxt)(9)(10)(11)
则公式8可以表示为: q ( x t ∣ x t + 1 ) q ^ ( y ∣ x t ) q ^ ( y ∣ x t + 1 ) \frac{q(x_{t}|x_{t+1}) \hat q(y|x_{t})}{\hat q(y|x_{t+1})} q^(yxt+1)q(xtxt+1)q^(yxt),且 q ^ ( y ∣ x t + 1 ) \hat q(y|x_{t+1}) q^(yxt+1)是不依赖于 x t x_{t} xt的,则 q ^ ( y ∣ x t , x t + 1 ) ∝ Z q ^ ( x t ∣ x t + 1 ) q ^ ( y ∣ x t , x t + 1 ) \hat q(y|x_{t},x_{t+1}) \propto Z\hat q(x_{t}|x_{t+1}) \hat q(y|x_{t},x_{t+1}) q^(yxt,xt+1)Zq^(xtxt+1)q^(yxt,xt+1) Z Z Z是一个归一化常数。 q q q是一个训练好的扩散过程,因此可以用 p θ ( x t ∣ x t + 1 ) p_{\theta}(x_{t}|x_{t+1}) pθ(xtxt+1)表示,那么该如何获得 q ^ ( y ∣ x t ) \hat q(y|x_{t}) q^(yxt)?我们可以在每一个时刻 x t x_{t} xt,训练一个分类器 p ϕ ( y ∣ x t ) p_{\phi}(y|x_{t}) pϕ(yxt)。接下来如何通过采样得到 x 0 x_{0} x0呢?
p θ , ϕ ( x t ∣ x t + 1 , y ) = Z p θ ( x t ∣ x t + 1 ) p ϕ ( y ∣ x t ) (12) p_{\theta, \phi}(x_{t}|x_{t+1},y)=Zp_{\theta}(x_{t}|x_{t+1})p_{\phi}(y|x_{t})\tag{12} pθ,ϕ(xtxt+1,y)=Zpθ(xtxt+1)pϕ(yxt)(12)
在DDPM中,假设了 p θ ( x t ∣ x t + 1 ) = N ( μ , Σ ) p_{\theta}(x_{t}|x_{t+1})=N(\mu,\Sigma) pθ(xtxt+1)=N(μ,Σ), 该多元高斯分布的似然函数为 l o g p θ ( x t ∣ x t + 1 ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + C logp_{\theta}(x_{t}|x_{t+1})=-\frac{1}{2}(x_{t}-\mu)^{T}\Sigma^{-1}(x_{t}-\mu)+C logpθ(xtxt+1)=21(xtμ)TΣ1(xtμ)+C 。同时,我们假设 l o g p ϕ ( y ∣ x t ) log p_{\phi}(y|x_{t}) logpϕ(yxt) Σ − 1 \Sigma^{-1} Σ1有更低的曲率,也就是说二阶导数更小。那么我们可以用Taylor公式将 l o g p ϕ ( y ∣ x t ) logp_{\phi}(y|x_{t}) logpϕ(yxt)展开到一阶导数:
l o g p ϕ ( y ∣ x t ) ≈ l o g p ϕ ( y ∣ x t ) ∣ x t = μ + ( x t − μ ) ∇ x t l o g p ϕ ( y ∣ x t ) ∣ x t = μ = ( x t − μ ) g + C 1 (13) \begin{aligned} log p_{\phi}(y|x_{t})&\approx logp_{\phi}(y|x_{t})|_{x_{t}=\mu}+(x_{t}-\mu)\nabla_{x_{t}}log p_{\phi}(y|x_{t})|_{x_{t}=\mu}\\ &=(x_{t}-\mu)g+C_{1} \tag{13} \end{aligned} logpϕ(yxt)logpϕ(yxt)xt=μ+(xtμ)xtlogpϕ(yxt)xt=μ=(xtμ)g+C1(13)
其中 g = ∇ x t l o g p ϕ ( y ∣ x t ) ∣ x t = μ g= \nabla_{x_{t}}log p_{\phi}(y|x_{t})|_{x_{t}=\mu} g=xtlogpϕ(yxt)xt=μ C 1 C_{1} C1是个常数。 则:
l o g ( p θ ( x t ∣ x t + 1 ) p ϕ ( y ∣ x t ) ) ≈ − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) g + C 2 = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + C 3 = l o g p ( z ) + C 4 , z ∼ N ( μ + Σ g , Σ ) \begin{align} log(p_{\theta}(x_{t}|x_{t+1})p_{\phi}(y|x_{t}))&\approx -\frac{1}{2}(x_{t}-\mu)^{T}\Sigma^{-1}(x_{t}-\mu)+(x_{t}-\mu)g+C_{2}\tag{14}\\ &= -\frac{1}{2}(x_{t}-\mu-\Sigma g)^{T}\Sigma^{-1}(x_{t}-\mu-\Sigma g)+\frac{1}{2}g ^{T}\Sigma g+C_{2}\tag{15}\\ &=-\frac{1}{2}(x_{t}-\mu-\Sigma g)^{T}\Sigma^{-1}(x_{t}-\mu-\Sigma g)+C_{3}\tag{16}\\ &=logp(z)+C_{4},z\sim N(\mu+\Sigma g, \Sigma)\tag{17} \end{align} log(pθ(xtxt+1)pϕ(yxt))21(xtμ)TΣ1(xtμ)+(xtμ)g+C2=21(xtμΣg)TΣ1(xtμΣg)+21gTΣg+C2=21(xtμΣg)TΣ1(xtμΣg)+C3=logp(z)+C4,zN(μ+Σg,Σ)(14)(15)(16)(17)
公式14中的 Σ = Σ T \Sigma=\Sigma^{T} Σ=ΣT可以推出公式15,而公式15中的第二项与 x t x_{t} xt无关,则可以推出公式16,公式17说明了这个新的似然为均值为 μ + Σ g \mu+\Sigma g μ+Σg,方差不变的新分布,也就是说新的逆扩散条件分布 p θ , ϕ p_{\theta,\phi} pθ,ϕ与DDPM中的逆过程仅仅相差在均值上 。那么将该方法应用在DDPM中,如下:
Guided Diffusion/Diffusion Models Beat GANs on Image Synthesis (Paper reading)_第1张图片

首先通过扩散模型中的神经网络获得均值和方差,可学习的方差为 Σ θ ( x t , t ) = e x p ( v l o g ( β t ) + ( 1 − v ) l o g β ~ t ) \Sigma_{\theta}(x_{t}, t)=exp(vlog(\beta_{t})+(1-v)log\tilde\beta_{t}) Σθ(xt,t)=exp(vlog(βt)+(1v)logβ~t)其中, β ~ t = 1 − α ˉ t − 1 1 − α ˉ t β t \tilde\beta_{t}=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_{t}}\beta_{t} β~t=1αˉt1αˉt1βt。然后我们从新的分布中进行采样得到 x t − 1 x_{t-1} xt1,这里的 g g g就是 x t x_{t} xt时刻分类器的梯度。

4. 总结

条件扩散模型是相对热点的研究方向,本文的添加条件的方法是一种简单高效的方法,应用于low-level任务中的文章目前也被研究探讨,与之前的条件扩散模型相比,guided diffusion models对于一些特殊的任务更友好,而且openai的代码质量也很高。本文还有很多值得讨论的内容,如神经网络的修改,DDIM的应用等。

你可能感兴趣的:(图像处理,扩散模型,python,python,算法,人工智能,计算机视觉,深度学习)