Swin-Transformer的Attention Mask笔记

Swin Transformer中将原来全局特征图的自注意力转为求多个windows内的attention,得到局部windows的交互,然后再像CNN卷积一样的思路,通过windows之间交互来获取全局的交互,这样可以避免原transform应用到视觉领域时遇到的问题:1、计算量巨大;2、特征变化大。
其核心设计久是shifted window,如下图所示,这种偏移方式能够获取不同windows的交互,但是也带来了问题,即4原来的4个windows变成了9个,这相当于由在4个windows内应用transformer变成了在9个windows内应用transformer,计算就变得复杂很多。为了简化这个计算,swin-transformer的实现中通过shift mask,使得只需要像原来一样在4个transformer就能完成全部9个windows的局部交互计算。
Swin-Transformer的Attention Mask笔记_第1张图片
实现是是这样的,首先,将特征图通过torch.roll函数将其变换,让9个windows凑成4个windows,然后像正常求4个windowns内部的交互那样执行transformer,到了Q、K的转置相乘得到系数矩阵的时候,需要添加我们的mask,让一部分特征无效(原因是我们将特征图做了torch.roll变换,原来有的9个windows除了4是本来就相邻的,彼此交互有意义,其它的非相邻部分求他们的交互是没道理的,所以让其他部分不生效)。
具体如下:

首先我们对Shift Window后的每个窗口都给上index,并且做一个roll操作(window_size=2, shift_size=-1)

Swin-Transformer的Attention Mask笔记_第2张图片

我们希望在计算Attention的时候,让具有相同index QK进行计算,而忽略不同index QK计算结果。

最后正确的结果如下图所示,每个Q和 K T K^T KT矩阵相乘代表的是torch.roll之后得到的4个伪相邻的子窗口内部求自注意力的过程。需要确保只有原来相邻的窗口才有意义,比如3和5两个窗口原来是不相邻的,只能是3内的位置间有效,5内的位置间有效,其它的无效;而4这个窗口原来就是一个整体,内部的每个相乘都有效。实现的方式就是添加mask,让图中Qmatmul K T K^T KT里数字部分为0,其它部分为-100,这样通过softmax时就只有这些部分有效。
Swin-Transformer的Attention Mask笔记_第3张图片
参考博客:

图解Swin Transformer
Swin-Transformer(原理 + 代码)详解

你可能感兴趣的:(transformer,深度学习,人工智能)