“让我们用一句话,让模型画出一幅画。”
在前几期中我们学习了 Denoising Diffusion Probabilistic Models(DDPM)如何在无条件情况下生成图像。而在本期,我们将跨入更具挑战性但也更酷的领域 —— 文本条件图像生成(Text-Conditional Generation)。
本期的主角是将 CLIP 模型与扩散模型结合,使模型可以根据你输入的 一句话 来“想象”并绘制出图像。比如输入 "a photo of a cat"
,就能生成类似猫的图像。
CLIP 是由 OpenAI 提出的多模态模型,全称为 Contrastive Language–Image Pretraining。它的核心思想是:
同一个图像和它的描述性文字,在语义空间中应该越接近越好。
CLIP 同时训练了两个编码器:
图像编码器:将图像转换成一个向量(embedding)。
文本编码器:将一段文本描述转换成另一个向量。
然后通过对比学习,使得图像和它的描述之间的向量距离尽量接近。
CLIP 在这期中的作用是:将文本转化为“引导”扩散模型生成图像的向量条件。
在原始的 DDPM 中,模型仅学习如何从纯噪声恢复出图像,没有任何“指导”信息。而现在,我们希望引导它生成与某个文本语义相关的图像。
将文本输入 CLIP 的文本编码器 → 得到文本嵌入 text_embedding
。
在 UNet 中加入文本条件 → 每一层都能“感知”到你想要生成的是“猫”还是“狗”。
在每个扩散时间步中,模型接收 x_t
和 text_embedding
作为输入,预测噪声。
pip install torch torchvision
pip install git+https://github.com/openai/CLIP.git
1. 加载 CLIP 模型
import clip
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)
def get_text_embedding(texts):
tokens = clip.tokenize(texts).to(device)
with torch.no_grad():
embeddings = clip_model.encode_text(tokens)
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
return embeddings
2. 定义条件 UNet
import torch.nn as nn
import torch.nn.functional as F
class UNetBlock(nn.Module):
def __init__(self, in_ch, out_ch, cond_dim):
super().__init__()
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
self.norm = nn.GroupNorm(8, out_ch)
self.cond_proj = nn.Linear(cond_dim, out_ch)
def forward(self, x, cond):
cond = self.cond_proj(cond).unsqueeze(-1).unsqueeze(-1)
h = self.conv1(x)
h = self.norm(h)
h = F.relu(h + cond)
h = self.conv2(h)
h = self.norm(h)
return F.relu(h + cond)
class ConditionalUNet(nn.Module):
def __init__(self, cond_dim=512):
super().__init__()
self.init = nn.Conv2d(3, 64, 3, padding=1)
self.down1 = UNetBlock(64, 128, cond_dim)
self.down2 = UNetBlock(128, 256, cond_dim)
self.middle = UNetBlock(256, 256, cond_dim)
self.up1 = UNetBlock(512, 128, cond_dim)
self.up2 = UNetBlock(256, 64, cond_dim)
self.out = nn.Conv2d(64, 3, 1)
def forward(self, x, t, cond):
x1 = self.init(x)
x2 = self.down1(x1, cond)
x3 = self.down2(x2, cond)
x4 = self.middle(x3, cond)
x = self.up1(torch.cat([x4, x3], dim=1), cond)
x = self.up2(torch.cat([x, x2], dim=1), cond)
return self.out(x)
3.定义扩散过程
T = 300
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
def q_sample(x_0, t, noise):
sqrt_alpha = torch.sqrt(alphas_cumprod[t])[:, None, None, None].to(x_0.device)
sqrt_one_minus_alpha = torch.sqrt(1 - alphas_cumprod[t])[:, None, None, None].to(x_0.device)
return sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise
4. 训练循环(简化)
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
dataset = CIFAR10(root='./data', download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
model = ConditionalUNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
for epoch in range(10):
for images, labels in dataloader:
images = images.to(device)
t = torch.randint(0, T, (images.size(0),), device=device).long()
noise = torch.randn_like(images)
# 文本标签(e.g. "cat", "airplane"...)
texts = [dataset.classes[i] for i in labels]
text_emb = get_text_embedding(texts)
x_t = q_sample(images, t, noise)
pred_noise = model(x_t, t, text_emb)
loss = F.mse_loss(pred_noise, noise)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
5. 文本条件采样
@torch.no_grad()
def sample_with_text(text, model, steps=T):
model.eval()
x = torch.randn(16, 3, 32, 32).to(device)
text_emb = get_text_embedding([text] * 16)
for i in reversed(range(steps)):
t = torch.full((x.size(0),), i, device=device, dtype=torch.long)
pred_noise = model(x, t, text_emb)
alpha = alphas_cumprod[t][:, None, None, None].to(device)
sqrt_alpha = torch.sqrt(alpha)
sqrt_one_minus_alpha = torch.sqrt(1 - alpha)
x_0 = (x - sqrt_one_minus_alpha * pred_noise) / sqrt_alpha
x_0 = x_0.clamp(-1, 1)
if i > 0:
noise = torch.randn_like(x)
beta = betas[t][:, None, None, None].to(device)
x = torch.sqrt(alpha) * x_0 + torch.sqrt(beta) * noise
else:
x = x_0
return x
6. 结果展示
import torchvision
import matplotlib.pyplot as plt
samples = sample_with_text("a photo of a dog", model)
# 显示
samples = (samples + 1) / 2 # [-1, 1] -> [0, 1]
grid = torchvision.utils.make_grid(samples, nrow=4)
plt.figure(figsize=(6, 6))
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.title("Generated Images: 'a photo of a dog'")
plt.show()
本期我们使用 CLIP + Diffusion 实现了文本条件图像生成。
CLIP 提供语义嵌入作为条件,引导扩散模型生成与文字描述相符的图像。
可以扩展到更多语义场景,比如 "a cartoon of Pikachu"
、"a red sports car"
等。
下一期将进一步探索更高级的条件方式,如 交叉注意力融合文本、Classifier-free Guidance 等。