URL | |
---|---|
paper | Classifier-Free Diffusion Guidance GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models |
github | https://github.com/openai/glide-text2im |
在classifier-guided这篇博客我们提到对于一般的DM(如DDPM, DDIM)的采样过程是直接从一个噪声分布,通过不断采样来生成图片。但这个方法生成的图片类别是随机的,classifier-guided通过额外训练一个分类器来不断矫正每一个时间步的生成图片,最终实现特定类别图片的生成。
Classifier-free的核心思路是:我们无需训练额外的分类器,直接训练带类别信息的噪声预测模型来实现特定类别图片的生成,即 ϵ θ ( x t , t ) → ϵ ^ θ ( x t , y , t ) \epsilon_{\theta}(x_t, t) \rightarrow \hat{\epsilon}_{\theta}(x_t, y, t) ϵθ(xt,t)→ϵ^θ(xt,y,t)。从而简化整体的pipeline。
此外,classifier-free方法不局限于类别信息的融入,它还能实现将语义信息融入到diffusion model中,实现更为灵活的文生图。这用classifier-guide是很难做到的。目前的很多工作如DALLE,Stable Diffusion, Imagen等都是Classifier-free形式。如:
下面我们来看他是怎么做的吧!
classifier-free diffusion的实现非常简单。下面对比普通的diffusion model,classifier-guided与classifier-free三种方式的差异。
模型 | 训练目标 | 实现功能 | 训练数据 |
---|---|---|---|
DM (DDPM, DDIM) | ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t) | 从服从高斯分布的噪声中生成图片 | 图片 |
classifier-guided DM | ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)和分类器 p ( y ∣ x t ) p(y|x_t) p(y∣xt) | 从服从高斯分布的噪声中生成特定类别的图片 | DM:图片 分类器:图片-标签对 |
classifier-free DM | ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t), ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t) | 从服从高斯分布的噪声中生成符合文本描述的图片 | 图片-文本对 |
回答3个问题深入理解classifier-free DM
我们知道,深度学习模型推理的本质可以理解为一系列的数值计算,因此将类别信息(或语义信息)融入到模型中需要预先将其转化为数值。转化的方法有很多,如可以用一个embedding layer
。也可以用NLP模型,如Bert、T5、CLIP的text encoder等将类别信息(或语义信息)转化为数值向量,一般称为text embedding
。随后需要将text embedding
和原本模型中的image representation
进行融合。最为常见且有效的方法是用交叉注意力机制CrossAttention
。具体来说就是将text embedding
作为注意力机制中的key
和value
,原始的图片表征作为query
。大家熟知的Stable Diffusion用的就是这个融入方法。交叉注意力机制融入语义信息的本质是spatial-wise attention。
class SpatialCrossAttention(nn.Module):
def __init__(self, dim, context_dim, heads=4, dim_head=32) -> None:
super(SpatialCrossAttention, self).__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.proj_in = nn.Conv2d(dim, context_dim, kernel_size=1, stride=1, padding=0)
self.to_q = nn.Linear(context_dim, hidden_dim, bias=False)
self.to_k = nn.Linear(context_dim, hidden_dim, bias=False)
self.to_v = nn.Linear(context_dim, hidden_dim, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x, context=None):
x_q = self.proj_in(x)
b, c, h, w = x_q.shape
x_q = rearrange(x_q, "b c h w -> b (h w) c")
if context is None:
context = x_q
if context.ndim == 2:
context = rearrange(context, "b c -> b () c")
q = self.to_q(x_q)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=self.heads), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=self.heads)
out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w)
out = self.to_out(out)
return out
该融入方法与time-embedding
的融入方法相同,在时间中往往会预先和time-embedding
进行融合,再融入到图片特征中,伪代码如下:
# mixture time-embedding and label embedding
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
emb_out = self.emb_layers(emb).type(h.dtype) # h is image feature
scale, shift = th.chunk(emb_out, 2, dim=1) # One half of the embedding is used for scaling and the other half for offset
h = h * (1 + scale) + shift
基于channel-wise的融入粒度没有CrossAttention
细。一般适用类别数量有限的特征融入,如时间embedding,类别embedding。而语义信息的融入更推荐上面CrossAttention
的方法。
ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t)的训练需要图文对,但互联网上具备文本描述的图片只是浩如烟海的图片海洋中的一小部分。仅用具备图文对数据训练 ϵ θ ( x t , y , t ) \epsilon_{\theta}(x_t, y,t) ϵθ(xt,y,t)将会大大束缚DM的生成多样性。另外,为了使得模型更好的捕获图文的联系 ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\empty,t) ϵθ(xt,y=∅,t)的数据不宜过多,否则模型生成结果的保真度会降低。反之,若 ϵ θ ( x t , y = ∅ , t ) \epsilon_{\theta}(x_t, y=\empty,t) ϵθ(xt,y=∅,t)数据过少,将会影响生成结果的多样性。需要根据实际的场景进行调整。
有两个实践中的trick需要注意:
classifier-free diffusion的采样生成过程与前面介绍的DDPM,DDIM类似。唯一有所区别的是将原本的 ϵ ( x t , t ) \epsilon(x_t, t) ϵ(xt,t)用下式代替。
ϵ ^ θ ( x t , y , t ) = ϵ θ ( x t , y = ∅ , t ) + s [ ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ] \begin{align} \hat{\epsilon}_{\theta}(x_t, y, t)=\epsilon_{\theta}(x_t, y=\empty,t) + s[\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ]\tag{1} \end{align} ϵ^θ(xt,y,t)=ϵθ(xt,y=∅,t)+s[ϵθ(xt,y,t)−ϵθ(xt,y=∅,t)](1)
下面给出详细的推导过程:
首先根据贝叶斯公式有
p ( y ∣ x t ) = p ( x t ∣ y ) p ( y ) ⏞ 先验分布 p ( x t ) ⇒ p ( y ∣ x t ) ∝ p ( x t ∣ y ) / p ( x t ) ⇒ 取对数 log p ( y ∣ x t ) = log p ( x t ∣ y ) − log p ( x t ) ⇒ 对 x t 求导 ∇ x t log p ( y ∣ x t ) = ∇ x t log p ( x t ∣ y ) − ∇ x t log p ( x t ) ⇒ 根据score function ∇ x t log p θ ( x t ) = − 1 1 − α ‾ t ϵ θ ( x t ) ∇ x t log p ( y ∣ x t ) = − 1 1 − α ‾ t ( ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ) (2) \begin{aligned} p (y| x_t) & = \frac{p (x_t|y) \overbrace{p(y)}^{\text{先验分布}} } {p(x_t) } \\ \Rightarrow p (y| x_t) & \propto p (x_t|y) / {p (x_t) } \\ \stackrel{取对数} \Rightarrow \log{p (y| x_t)} & = \log{p (x_t|y)} - \log{{p (x_t) }} \\ \stackrel{对x_t求导} \Rightarrow \nabla_{x_t}\log{p (y| x_t)} & = \nabla_{x_t}\log{p (x_t|y)} - \nabla_{x_t}\log{{p (x_t) }} \\ \stackrel{\text{根据score function} \nabla_{x_t} \log p_\theta (x_t) = - \frac{1}{\sqrt{1 - \overline{\alpha}_t}} \epsilon_\theta(x_t)} \Rightarrow \nabla_{x_t}\log{p (y| x_t)} & = - \frac{1}{\sqrt{1 - \overline{\alpha}_t}}(\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ) \end{aligned} \tag{2} p(y∣xt)⇒p(y∣xt)⇒取对数logp(y∣xt)⇒对xt求导∇xtlogp(y∣xt)⇒根据score function∇xtlogpθ(xt)=−1−αt1ϵθ(xt)∇xtlogp(y∣xt)=p(xt)p(xt∣y)p(y) 先验分布∝p(xt∣y)/p(xt)=logp(xt∣y)−logp(xt)=∇xtlogp(xt∣y)−∇xtlogp(xt)=−1−αt1(ϵθ(xt,y,t)−ϵθ(xt,y=∅,t))(2)
当我们得到 ∇ x t log p ( y ∣ x t ) \nabla_{x_t}\log{p (y| x_t)} ∇xtlogp(y∣xt),参考classifier-guided的式(17)
ϵ ^ ( x t ∣ y ) ⏟ 本文中的 ϵ ^ θ ( x t , y , t ) : = ϵ θ ( x t ) ⏟ 本文中的 ϵ θ ( x t , y = ∅ , t ) − s 1 − α ‾ t ∇ x t log p ϕ ( y ∣ x t ) (3) \underbrace{\hat{\epsilon}(x_t|y)}_{\text{本文中的}\hat{\epsilon}_{\theta}(x_t, y, t)} := \underbrace{\epsilon_\theta(x_t)}_{\text{本文中的}\epsilon_{\theta}(x_t, y=\empty, t)} - s\sqrt{1 - \overline{\alpha}_t}\nabla_{x_t} \log{p_\phi(y|x_t)} \tag{3} 本文中的ϵ^θ(xt,y,t) ϵ^(xt∣y):=本文中的ϵθ(xt,y=∅,t) ϵθ(xt)−s1−αt∇xtlogpϕ(y∣xt)(3)
可得
ϵ ^ θ ( x t , y , t ) = ϵ θ ( x t , y = ∅ , t ) + s [ ϵ θ ( x t , y , t ) − ϵ θ ( x t , y = ∅ , t ) ] \begin{align} \hat{\epsilon}_{\theta}(x_t, y, t)=\epsilon_{\theta}(x_t, y=\empty,t) + s[\epsilon_{\theta}(x_t, y, t) - \epsilon_{\theta}(x_t, y=\empty, t) ]\tag{4} \end{align} ϵ^θ(xt,y,t)=ϵθ(xt,y=∅,t)+s[ϵθ(xt,y,t)−ϵθ(xt,y=∅,t)](4)
后面的采样过程与之前的方式一致。
本文详细介绍了classifier-free的提出背景与具体实现方案。它是后续一系列如stable diffusion,DALLE等文生图工作的基石。
[1]: Classifier-Free Diffusion Guidance
[2]: GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models