“如何让扩散模型更强地听话?”
在第9期中,我们成功实现了 CLIP + Diffusion 的文本引导图像生成模型。虽然能生成与文本描述大致匹配的图像,但你会发现:
有时图像与文本契合度不高;
控制力不够,输入 "a red car"
也可能出现黄色或蓝色;
文本引导效果“太温柔”了。
这时候,Classifier-Free Guidance(简称 CFG)就派上用场了。
传统的条件扩散方法(如 Guided Diffusion)使用一个额外的分类器去引导模型生成符合某个类别的图像。这种方式虽然有效,但缺点明显:
需要额外训练分类器;
无法用于文本这样的复杂条件;
效率较低。
CFG 的核心思想是:
不训练额外的分类器,而是利用 同一个模型同时学习条件生成和无条件生成,然后在推理时动态融合两者。
训练阶段:
有条件样本时,使用文本条件训练。
一定概率下去除条件,作为无条件训练样本。
推理阶段:
同时用模型计算条件预测和无条件预测;
用一个超参数 guidance_scale
进行融合:final_output = uncond + guidance_scale * (cond - uncond)
我们沿用第9期的 ConditionalUNet
模型,添加无条件路径,并实现带 guidance_scale
的采样流程。
1. 修改 UNet 支持 None 条件
class ConditionalUNet(nn.Module):
def __init__(self, cond_dim=512):
super().__init__()
self.cond_dim = cond_dim
self.init = nn.Conv2d(3, 64, 3, padding=1)
self.down1 = UNetBlock(64, 128, cond_dim)
self.down2 = UNetBlock(128, 256, cond_dim)
self.middle = UNetBlock(256, 256, cond_dim)
self.up1 = UNetBlock(512, 128, cond_dim)
self.up2 = UNetBlock(256, 64, cond_dim)
self.out = nn.Conv2d(64, 3, 1)
def forward(self, x, t, cond=None):
if cond is None:
cond = torch.zeros(x.size(0), self.cond_dim).to(x.device)
x1 = self.init(x)
x2 = self.down1(x1, cond)
x3 = self.down2(x2, cond)
x4 = self.middle(x3, cond)
x = self.up1(torch.cat([x4, x3], dim=1), cond)
x = self.up2(torch.cat([x, x2], dim=1), cond)
return self.out(x)
2. 添加 Classifier-Free Guidance 的采样函数
@torch.no_grad()
def sample_with_cfg(text, model, steps=T, guidance_scale=5.0):
model.eval()
x = torch.randn(16, 3, 32, 32).to(device)
text_emb = get_text_embedding([text] * x.size(0))
uncond_emb = torch.zeros_like(text_emb)
for i in reversed(range(steps)):
t = torch.full((x.size(0),), i, device=device, dtype=torch.long)
# 条件和无条件输出
pred_cond = model(x, t, text_emb)
pred_uncond = model(x, t, uncond_emb)
pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
alpha = alphas_cumprod[t][:, None, None, None].to(x.device)
sqrt_alpha = torch.sqrt(alpha)
sqrt_one_minus_alpha = torch.sqrt(1 - alpha)
x_0 = (x - sqrt_one_minus_alpha * pred) / sqrt_alpha
x_0 = x_0.clamp(-1, 1)
if i > 0:
noise = torch.randn_like(x)
beta = betas[t][:, None, None, None].to(x.device)
x = torch.sqrt(alpha) * x_0 + torch.sqrt(beta) * noise
else:
x = x_0
return x
3. 可视化结果对比
import torchvision
import matplotlib.pyplot as plt
samples = sample_with_cfg("a red sports car", model, guidance_scale=5.0)
samples = (samples + 1) / 2
grid = torchvision.utils.make_grid(samples, nrow=4)
plt.figure(figsize=(6, 6))
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.title("Generated: 'a red sports car'")
plt.show()
项目 | 优点 | 缺点 |
---|---|---|
无需额外分类器 | ✅ 简化训练流程 | ❌ 推理中需双前向 |
可用于复杂条件 | ✅ 支持文本、图像 | ❌ guidance_scale 需要调参 |
效果提升明显 | ✅ 提升文本一致性 | - |
多条件引导(文本 + 类别 + 图像);
用交叉注意力提升文本感知;
使用 pretrained 文本嵌入(如 T5/BERT)增强语义理解;
更大的 UNet backbone(例如 Stable Diffusion 的 UNet2DConditionModel);
多阶段生成(先粗图,再细化);
本期我们详细介绍了 Classifier-Free Guidance(CFG) 方法;
展示了如何融合条件和无条件的输出提升控制力;
用一段简洁代码实现了比第9期更强的文本引导效果;
CFG 是当前主流扩散文本生成方法(如 Imagen、Stable Diffusion)中核心组件。