- Attention的基本流程是:查询q与键k相乘得到注意力分数,经softmax归一化后得到权重a,a再与值v相乘得到注意力输出。这篇博客讲得很清晰。
- TopFormer使用多头注意力机制
- 查询qq的每个头中特征图的每个元素有key_dim个特征
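- 键kk与qq每个头的特征维度相同(均为key_dim);代码中qq经permute变为(B, num_heads, H*W, key_dim),kk保持(B, num_heads, key_dim, H*W),两者互为转置,因此可以直接做矩阵乘法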
- 值vv和LeViT一致,扩大了每个元素的维度:每个头中每个元素的特征数为d = attn_ratio × key_dim(默认是key_dim的4倍),特征更多
class Attention(torch.nn.Module):
    """Multi-head self-attention over a 2-D feature map.

    Queries, keys and values are produced by three ``Conv2d_BN`` projections
    of the input. Each head uses ``key_dim`` channels for queries and keys
    and ``int(attn_ratio * key_dim)`` channels for values. The attended
    values go through ``activation()`` followed by a final ``Conv2d_BN``
    constructed as ``Conv2d_BN(self.dh, dim, ...)``.

    Args:
        dim: Channel count passed as the first argument of the q/k/v
            projections and as the second argument of the output projection
            (presumably input and output channels of the block).
        key_dim: Per-head channel count of queries and keys.
        num_heads: Number of attention heads.
        attn_ratio: Multiplier giving the per-head value channel count,
            ``int(attn_ratio * key_dim)``.
        activation: Zero-argument callable returning a ``torch.nn.Module``
            (it is instantiated and placed in a ``torch.nn.Sequential``).
            Required in practice even though the signature defaults to
            ``None``.
        norm_cfg: Normalisation config forwarded unchanged to every
            ``Conv2d_BN``.

    Raises:
        TypeError: If ``activation`` is not callable (including the
            ``None`` default).
    """

    # NOTE(review): the ``norm_cfg`` default is a single dict shared by every
    # instance that omits the argument. This class only forwards it; it must
    # not be mutated in place by ``Conv2d_BN`` -- confirm against its code.
    def __init__(self, dim, key_dim, num_heads,
                 attn_ratio=4,
                 activation=None,
                 norm_cfg=dict(type='BN', requires_grad=True),):
        super().__init__()

        # Fail fast with a clear message instead of the opaque
        # "'NoneType' object is not callable" raised after half of the
        # layers have already been built.
        if not callable(activation):
            raise TypeError(
                "activation must be a callable returning a torch.nn.Module "
                f"(e.g. torch.nn.ReLU), got {activation!r}"
            )

        self.num_heads = num_heads
        # NOTE(review): ``scale`` is stored but never applied in ``forward``;
        # the attention logits are unscaled. Applying it would change the
        # numerical output, so the behaviour is deliberately left as is.
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        # Total query/key channels across all heads.
        self.nh_kd = nh_kd = key_dim * num_heads
        # Per-head value channels, and the total across all heads.
        self.d = int(attn_ratio * key_dim)
        self.dh = self.d * num_heads
        self.attn_ratio = attn_ratio

        self.to_q = Conv2d_BN(dim, nh_kd, 1, norm_cfg=norm_cfg)
        self.to_k = Conv2d_BN(dim, nh_kd, 1, norm_cfg=norm_cfg)
        self.to_v = Conv2d_BN(dim, self.dh, 1, norm_cfg=norm_cfg)

        self.proj = torch.nn.Sequential(
            activation(),
            Conv2d_BN(self.dh, dim, bn_weight_init=0, norm_cfg=norm_cfg),
        )

    def forward(self, x):
        """Apply attention across all spatial positions of ``x``.

        Args:
            x: Input for which ``get_shape`` yields four values, unpacked
                here as ``(B, C, H, W)``.

        Returns:
            The result of ``self.proj`` applied to the attended values, which
            are reshaped to ``(B, self.dh, H, W)`` beforehand.
        """
        B, _, H, W = get_shape(x)
        num_tokens = H * W

        # Queries: (B, heads, key_dim, HW) -> (B, heads, HW, key_dim).
        qq = self.to_q(x).reshape(
            B, self.num_heads, self.key_dim, num_tokens).permute(0, 1, 3, 2)
        # Keys stay as (B, heads, key_dim, HW), i.e. already laid out as the
        # transpose of the queries, so no extra permute is needed.
        kk = self.to_k(x).reshape(
            B, self.num_heads, self.key_dim, num_tokens)
        # Values: (B, heads, d, HW) -> (B, heads, HW, d).
        vv = self.to_v(x).reshape(
            B, self.num_heads, self.d, num_tokens).permute(0, 1, 3, 2)

        # (B, heads, HW, HW) logits, normalised over the key axis.
        attn = torch.matmul(qq, kk)
        attn = attn.softmax(dim=-1)

        # Weighted sum of values: (B, heads, HW, d).
        xx = torch.matmul(attn, vv)

        # Merge heads back into channels and restore the spatial layout.
        xx = xx.permute(0, 1, 3, 2).reshape(B, self.dh, H, W)
        xx = self.proj(xx)
        return xx