在 Stable Diffusion 的 U-Net 中,Cross Attention
是将文本提示与图像特征对齐融合的关键模块,本文将结合一段 Python 实现代码,逐行解释其原理。
Cross Attention(交叉注意力)是指:让一组 Query 向量(比如图像特征)去 attend 另一组 Key-Value 向量(比如文本上下文),以融合跨模态信息。
在 Stable Diffusion 中,它的作用就是:
让图像特征与文本 Prompt 对齐,从而生成符合描述的图像。
下面是一个精简版的 cross attention 实现(基于 PyTorch):
import torch
import torch.nn as nn
import torch.nn.functional as F
def cross_attention(x: 'b c h w', context: 'b len dim'):
batch_size, channels, height, width = x.shape
x_flat = x.view(batch_size, channels, -1).permute(0, 2, 1) # (b, h*w, c)
q = nn.Linear(channels, dim)(x_flat) # (b, h*w, dim)
k = nn.Linear(dim, dim)(context) # (b, len, dim)
v = nn.Linear(dim, dim)(context) # (b, len, dim)
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (dim ** 0.5) # (b, h*w, len)
attn_weights = F.softmax(attn_scores, dim=-1) # (b, h*w, len)
attn_output = torch.matmul(attn_weights, v) # (b, h*w, dim)
out = nn.Linear(dim, channels)(attn_output) # (b, h*w, c)
out = out.permute(0, 2, 1).view(batch_size, channels, height, width)
return out
每一行在干什么?
1️⃣ 输入格式
• x: 图像特征图,形状为 (batch_size, channels, height, width)。
• context: 上下文特征,通常是文本 Prompt 编码,形状为 (batch_size, seq_len, dim)。
2️⃣ 将图像特征展平
x_flat = x.view(batch_size, channels, -1).permute(0, 2, 1) # (b, h*w, c)
将二维空间的图像特征展平成一个序列,即把每个像素位置看作一个 token。
3️⃣ 构造 Query / Key / Value
q = nn.Linear(channels, dim)(x_flat) # 查询向量,图像发出请求
k = nn.Linear(dim, dim)(context) # 键向量,文本提供信息索引
v = nn.Linear(dim, dim)(context) # 值向量,文本的实际信息内容
这里 q 是来自图像,k 和 v 是来自文本。
4️⃣ 计算注意力得分
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (dim ** 0.5)
这是标准的 Scaled Dot-Product Attention:
对每个图像位置的 q 与所有文本位置的 k 做点积。
5️⃣ softmax 得到注意力权重
attn_weights = F.softmax(attn_scores, dim=-1)
将得分转化为概率分布,得到每个图像位置对文本各 token 的注意力程度。
6️⃣ 加权求和融合文本特征
attn_output = torch.matmul(attn_weights, v) # (b, h*w, dim)
使用注意力权重对文本的值向量 v 做加权,得到融合文本信息后的图像 token。
7️⃣ 映射回原图像空间
out = nn.Linear(dim, channels)(attn_output)
out = out.permute(0, 2, 1).view(batch_size, channels, height, width)
将融合后的特征映射回原始的通道维度,并 reshape 成 (b, c, h, w) 格式。