Stable Diffusion 中 Cross Attention 实现原理解析(含代码讲解)

在 Stable Diffusion 的 U-Net 中,Cross Attention 是将文本提示与图像特征对齐融合的关键模块,本文将结合一段 Python 实现代码,逐行解释其原理。

Cross Attention 是什么?

Cross Attention(交叉注意力)是指:让一组 Query 向量(比如图像特征)去 attend 另一组 Key-Value 向量(比如文本上下文),以融合跨模态信息。

在 Stable Diffusion 中,它的作用就是:

让图像特征与文本 Prompt 对齐,从而生成符合描述的图像。


代码实现

下面是一个精简版的 cross attention 实现(基于 PyTorch):

import torch
import torch.nn as nn
import torch.nn.functional as F

def cross_attention(x: 'b c h w', context: 'b len dim'):
    batch_size, channels, height, width = x.shape
    x_flat = x.view(batch_size, channels, -1).permute(0, 2, 1)  # (b, h*w, c)

    q = nn.Linear(channels, dim)(x_flat)      # (b, h*w, dim)
    k = nn.Linear(dim, dim)(context)          # (b, len, dim)
    v = nn.Linear(dim, dim)(context)          # (b, len, dim)

    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (dim ** 0.5)  # (b, h*w, len)
    attn_weights = F.softmax(attn_scores, dim=-1)                      # (b, h*w, len)

    attn_output = torch.matmul(attn_weights, v)        # (b, h*w, dim)
    out = nn.Linear(dim, channels)(attn_output)        # (b, h*w, c)
    out = out.permute(0, 2, 1).view(batch_size, channels, height, width)
    return out

每一行在干什么?

1️⃣ 输入格式
• x: 图像特征图,形状为 (batch_size, channels, height, width)。
• context: 上下文特征,通常是文本 Prompt 编码,形状为 (batch_size, seq_len, dim)。
2️⃣ 将图像特征展平

x_flat = x.view(batch_size, channels, -1).permute(0, 2, 1)  # (b, h*w, c)

将二维空间的图像特征展平成一个序列,即把每个像素位置看作一个 token。
3️⃣ 构造 Query / Key / Value

q = nn.Linear(channels, dim)(x_flat)  # 查询向量,图像发出请求
k = nn.Linear(dim, dim)(context)      # 键向量,文本提供信息索引
v = nn.Linear(dim, dim)(context)      # 值向量,文本的实际信息内容

这里 q 是来自图像,k 和 v 是来自文本。
4️⃣ 计算注意力得分

attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (dim ** 0.5)

这是标准的 Scaled Dot-Product Attention:
对每个图像位置的 q 与所有文本位置的 k 做点积。
5️⃣ softmax 得到注意力权重

attn_weights = F.softmax(attn_scores, dim=-1)

将得分转化为概率分布,得到每个图像位置对文本各 token 的注意力程度。
6️⃣ 加权求和融合文本特征

attn_output = torch.matmul(attn_weights, v)  # (b, h*w, dim)

使用注意力权重对文本的值向量 v 做加权,得到融合文本信息后的图像 token。
7️⃣ 映射回原图像空间

out = nn.Linear(dim, channels)(attn_output)
out = out.permute(0, 2, 1).view(batch_size, channels, height, width)

将融合后的特征映射回原始的通道维度,并 reshape 成 (b, c, h, w) 格式。

你可能感兴趣的:(面试问题,stable,diffusion,人工智能,深度学习)