在过往的十年里,卷积神经网络CNN支配了计算机视觉的技术发展,下一个十年,或许将由Transformer接过接力棒了?
还记得在2020年的年末回顾中(《[年终AI大事件回顾]我眼中的2020年TOP5》),我们梳理了2020年的TOP5的AI领域研究工作。其中被排在第四位的,便是一篇基于Transformer的视觉领域的工作。
当时给出的评语是:
在更复杂的图像任务上的效果还有待探索和加强,但这一工作让我们看到了“跨界”的可能,相信未来会有更多在视觉上的工作基于Transformer而展开。
当时想的是:除了单纯用Transformer的结构完全替代CNN解决CV问题之外,是时候把CNN和Transformer结合在一起了吧。类似的工作比如说DETR,用CNN提取图像特征,在之后接了Transformer的encoder和decoder。
2021年1月27日在arxiv上发的一篇文章Bottleneck Transformers for Visual Recognition同样是采用了CNN+Transformer,但在我看来,似乎是更加elegant的做法:
1. 将Transformer的Self-attention融入了一个CNN的backbone中,而非叠加;
2. 具体来说是在ResNet的最后三个bottleneck blocks中用MHSA(多头自注意力层,Multi-Head Self-attention)替换了原本的3x3卷积(下图)。这些新的blocks被命名为BoT blocks,这一新的网络被命名为BotNet。
这样做的好处是显而易见的:
1. 能够利用成熟的、经过检验的CNN网络结构提取特征,CNN在视觉领域是有一些先验或者inductive biases的;
2. 用CNN对输入图像做了下采样后,再由self-attention进行运算,相比于直接使用self-attention在原图上处理,能够降低运算量;
3. 这样的设计能够与其他方法结合,例如可能作为backbone应用于DETR中。
在MHSA层中,特征输入X经过WQ, WK, WV三个矩阵映射成q,k,v,分别代表query,key和value。Self-attention的操作一般是对qkv进行计算。
Multi-Head(多头)体现在对每一个head都有不同的WQ, WK, WV,完成对特征输入的映射,并进行上述自注意力的运算,以拓展模型在不同的表示空间里学习。
此外,由于上述的multi-head和self-attention操作中,没有引入与图像中位置相关的信息。因此,引入了相对位置编码(relative position encoding):R_h和R_w,分别表征高度和宽度编码。
这三个操作结合起来便是下图中的结构,对于熟悉Transformer的同学而言,其实并没有太多特殊的操作。
这里引用一段在Pytorch上对这一Attention层的非官方实现[3]:
class Attention(nn.Module):
def __init__(
self,
*,
dim,
fmap_size,
heads = 4,
dim_head = 128,
rel_pos_emb = False
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = heads * dim_head
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
rel_pos_class = AbsPosEmb if not rel_pos_emb else RelPosEmb
self.pos_emb = rel_pos_class(fmap_size, dim_head)
def forward(self, fmap):
heads, b, c, h, w = self.heads, *fmap.shape
q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), (q, k, v))
q *= self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
sim += self.pos_emb(q)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
return out
其中位置编码的做法(求qr^T):
def rel_to_abs(x):
"""
Converts relative indexing to absolute.
Input: [bs, heads, length, 2*length - 1]
Output: [bs, heads, length, length]
"""
b, h, l, _, device, dtype = *x.shape, x.device, x.dtype
dd = {'device': device, 'dtype': dtype}
col_pad = torch.zeros((b, h, l, 1), **dd)
x = torch.cat((x, col_pad), dim = 3)
flat_x = rearrange(x, 'b h l c -> b h (l c)')
flat_pad = torch.zeros((b, h, l - 1), **dd)
flat_x_padded = torch.cat((flat_x, flat_pad), dim = 2)
final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)
final_x = final_x[:, :, :l, (l-1):]
return final_x
def relative_logits_1d(q, rel_k):
"""
Compute relative logits along one dimenion.
`q`: [bs, heads, height, width, dim]
`rel_k`: [2*width - 1, dim]
"""
b, heads, h, w, dim = q.shape
logits = einsum('b h x y d, r d -> b h x y r', q, rel_k)
logits = rearrange(logits, 'b h x y r -> b (h x) y r')
logits = rel_to_abs(logits)
logits = logits.reshape(b, heads, h, w, w)
logits = expand_dim(logits, dim = 3, k = h)
return logits
class RelPosEmb(nn.Module):
def __init__(
self,
fmap_size,
dim_head
):
super().__init__()
height, width = pair(fmap_size)
scale = dim_head ** -0.5
self.fmap_size = fmap_size
self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)
def forward(self, q):
h, w = self.fmap_size
q = rearrange(q, 'b h (x y) d -> b h x y d', x = h, y = w)
rel_logits_w = relative_logits_1d(q, self.rel_width)
rel_logits_w = rearrange(rel_logits_w, 'b h x i y j-> b h (x y) (i j)')
q = rearrange(q, 'b h x y d -> b h y x d')
rel_logits_h = relative_logits_1d(q, self.rel_height)
rel_logits_h = rearrange(rel_logits_h, 'b h x i y j -> b h (y x) (j i)')
return rel_logits_w + rel_logits_h
下图定量分析了BoTNets在ImageNet数据集上的性能:top-1 acc与EfficientNet B7匹敌,但运算速度更快。
定性分析分别使用ResNet50(下图左)和BoTNet50(下图右)作为MaskRCNN的backbone网络:
总结一下,这篇论文虽然做了一些不一样的工作,但却只是开了个头。
在我看来,未来必然的发展是:
1. 改进Transformer的计算和存储效率,使之能适用于低算力平台;
2. CNN+Transformer的结构会存在一段时间,但最终被Transformer完全取代;
3. 最终的形态是类似于在NLP领域的应用——在大量图像数据上训练出的Transformer网络,能够适用于多个下游视觉任务。
最后,在这里期待一下,下一个“造福人类”、引领下一个十年的方法是否会出现在这个领域呢。
参考资料:
[1]https://arxiv.org/pdf/2010.11929.pdf
[2] https://arxiv.org/abs/2101.11605
[3] https://github.com/lucidrains/bottleneck-transformer-pytorch
- END -
新朋友们可以看看我过往的相关文章⬇
【相关推荐阅读】
[年终AI大事件回顾]我眼中的2020年TOP5
模式识别学科发展报告丨前言
梯度手术-多任务学习优化方法[NeurIPS 2020]
欢迎分享/转载,并注明出处。