第9期:文本条件生成(CLIP + Diffusion)详解

“让我们用一句话,让模型画出一幅画。”

在前几期中我们学习了 Denoising Diffusion Probabilistic Models(DDPM)如何在无条件情况下生成图像。而在本期,我们将跨入更具挑战性但也更酷的领域 —— 文本条件图像生成(Text-Conditional Generation)

本期的主角是将 CLIP 模型与扩散模型结合,使模型可以根据你输入的 一句话 来“想象”并绘制出图像。比如输入 "a photo of a cat",就能生成类似猫的图像。

一、什么是 CLIP?

CLIP 是由 OpenAI 提出的多模态模型,全称为 Contrastive Language–Image Pretraining。它的核心思想是:

同一个图像和它的描述性文字,在语义空间中应该越接近越好。

CLIP 同时训练了两个编码器:

  • 图像编码器:将图像转换成一个向量(embedding)。

  • 文本编码器:将一段文本描述转换成另一个向量。

然后通过对比学习,使得图像和它的描述之间的向量距离尽量接近。

CLIP 在这期中的作用是:将文本转化为“引导”扩散模型生成图像的向量条件

二、扩散模型如何结合文本?

在原始的 DDPM 中,模型仅学习如何从纯噪声恢复出图像,没有任何“指导”信息。而现在,我们希望引导它生成与某个文本语义相关的图像。

做法:

  1. 将文本输入 CLIP 的文本编码器 → 得到文本嵌入 text_embedding

  2. 在 UNet 中加入文本条件 → 每一层都能“感知”到你想要生成的是“猫”还是“狗”。

  3. 在每个扩散时间步中,模型接收 x_ttext_embedding 作为输入,预测噪声。

完整实现(含代码)

安装依赖

pip install torch torchvision
pip install git+https://github.com/openai/CLIP.git

1. 加载 CLIP 模型

import clip
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

def get_text_embedding(texts):
    tokens = clip.tokenize(texts).to(device)
    with torch.no_grad():
        embeddings = clip_model.encode_text(tokens)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
    return embeddings

2. 定义条件 UNet

import torch.nn as nn
import torch.nn.functional as F

class UNetBlock(nn.Module):
    def __init__(self, in_ch, out_ch, cond_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm = nn.GroupNorm(8, out_ch)
        self.cond_proj = nn.Linear(cond_dim, out_ch)

    def forward(self, x, cond):
        cond = self.cond_proj(cond).unsqueeze(-1).unsqueeze(-1)
        h = self.conv1(x)
        h = self.norm(h)
        h = F.relu(h + cond)
        h = self.conv2(h)
        h = self.norm(h)
        return F.relu(h + cond)

class ConditionalUNet(nn.Module):
    def __init__(self, cond_dim=512):
        super().__init__()
        self.init = nn.Conv2d(3, 64, 3, padding=1)
        self.down1 = UNetBlock(64, 128, cond_dim)
        self.down2 = UNetBlock(128, 256, cond_dim)
        self.middle = UNetBlock(256, 256, cond_dim)
        self.up1 = UNetBlock(512, 128, cond_dim)
        self.up2 = UNetBlock(256, 64, cond_dim)
        self.out = nn.Conv2d(64, 3, 1)

    def forward(self, x, t, cond):
        x1 = self.init(x)
        x2 = self.down1(x1, cond)
        x3 = self.down2(x2, cond)
        x4 = self.middle(x3, cond)
        x = self.up1(torch.cat([x4, x3], dim=1), cond)
        x = self.up2(torch.cat([x, x2], dim=1), cond)
        return self.out(x)

3.定义扩散过程

T = 300
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

def q_sample(x_0, t, noise):
    sqrt_alpha = torch.sqrt(alphas_cumprod[t])[:, None, None, None].to(x_0.device)
    sqrt_one_minus_alpha = torch.sqrt(1 - alphas_cumprod[t])[:, None, None, None].to(x_0.device)
    return sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise

4. 训练循环(简化)

from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = CIFAR10(root='./data', download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

model = ConditionalUNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

for epoch in range(10):
    for images, labels in dataloader:
        images = images.to(device)
        t = torch.randint(0, T, (images.size(0),), device=device).long()
        noise = torch.randn_like(images)

        # 文本标签(e.g. "cat", "airplane"...)
        texts = [dataset.classes[i] for i in labels]
        text_emb = get_text_embedding(texts)

        x_t = q_sample(images, t, noise)
        pred_noise = model(x_t, t, text_emb)

        loss = F.mse_loss(pred_noise, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

5. 文本条件采样

@torch.no_grad()
def sample_with_text(text, model, steps=T):
    model.eval()
    x = torch.randn(16, 3, 32, 32).to(device)
    text_emb = get_text_embedding([text] * 16)

    for i in reversed(range(steps)):
        t = torch.full((x.size(0),), i, device=device, dtype=torch.long)
        pred_noise = model(x, t, text_emb)

        alpha = alphas_cumprod[t][:, None, None, None].to(device)
        sqrt_alpha = torch.sqrt(alpha)
        sqrt_one_minus_alpha = torch.sqrt(1 - alpha)
        x_0 = (x - sqrt_one_minus_alpha * pred_noise) / sqrt_alpha
        x_0 = x_0.clamp(-1, 1)

        if i > 0:
            noise = torch.randn_like(x)
            beta = betas[t][:, None, None, None].to(device)
            x = torch.sqrt(alpha) * x_0 + torch.sqrt(beta) * noise
        else:
            x = x_0
    return x

6. 结果展示

import torchvision
import matplotlib.pyplot as plt

samples = sample_with_text("a photo of a dog", model)

# 显示
samples = (samples + 1) / 2  # [-1, 1] -> [0, 1]
grid = torchvision.utils.make_grid(samples, nrow=4)
plt.figure(figsize=(6, 6))
plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
plt.axis('off')
plt.title("Generated Images: 'a photo of a dog'")
plt.show()

总结

  • 本期我们使用 CLIP + Diffusion 实现了文本条件图像生成。

  • CLIP 提供语义嵌入作为条件,引导扩散模型生成与文字描述相符的图像。

  • 可以扩展到更多语义场景,比如 "a cartoon of Pikachu""a red sports car" 等。

  • 下一期将进一步探索更高级的条件方式,如 交叉注意力融合文本Classifier-free Guidance 等。

你可能感兴趣的:(人工智能,深度学习)