Swin Transformer算法解读

目录

一、Swin-Transformer整体架构

二、Patch Embedding

三、Swin-Transformer Block

(1)cyclic shift特征图移位操作

(2)window partition/reverse

(3)Window Attention

(4)Attention Mask

(5)merge windows

四、patch merging (down sample)

五、Transformer Block核心逻辑图


本文参考:论文详解:Swin Transformer - 知乎

一、Swin-Transformer整体架构

Swin Transformer算法解读_第1张图片

整个模型采取层次化的设计,除了最后一个BasicLayer外,每个BasicLayer都会在最后通过Patch Merging层缩小输出特征图的分辨率,进行下采样(比如avgPooling池化)操作,像CNN一样逐层扩大感受野,以便获取到全局的信息。

二、Patch Embedding

在进入Block前,需要通过patch_size为4的卷积层将图片切成一个个patch,然后嵌入向量Embedding,将embedding_size转变为96(可以将CV中图片的通道数理解为NLP中token的词嵌入长度)。

这里通过二维卷积层,将stride,kernel_size设置为patch_size大小,设定输出通道来确定嵌入向量的大小。最后将H,W维度展开,并移动到第一维度。

输入的H=W=224是在dataloader阶段的transform中完成图片Height和Width调整的。

三、Swin-Transformer Block

传统的Transformer是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transfomer则将注意力的计算限制在每个窗口内,进而减少了计算量。

Swin Transformer算法解读_第2张图片

Window Attention是在每个窗口下计算注意力的,为了更好地和其他window进行信息交互,Swin Transformer还引入了shifted window 操作。左边是没有重叠的window attention,而右边则是将窗口进行移位的Shift Window Attention。可以看到移位后的窗口包含了原本相邻窗口的元素。但这也引入了一个新问题,即window的个数翻倍了,由原本4个窗口变成了9个窗口。在实际代码里,通过对特征图位移,并给Attention设置mask来间接实现的。能在保持原有的windows个数下,最后的计算结果等价。

(1)cyclic shift特征图移位操作

代码里面对特征图移位是通过torch.roll来实现的。

   ->(步骤1)   ->(步骤2)

步骤1:torch.roll(a, shifts=-1, dims=0)

步骤2:torch.roll(b, shifts=-1, dims=1)

如果需要reverse cyclic shift的话只需要把参数shifts设置为对应的正数值。

(2)window partition/reverse

window partition函数是用于对张量划分窗口,指定窗口大小。将原本的张量从B H W C划分成num_windows * B, window_size, window_size, C。其中num_windows=H*W/(window_size*window_size),即窗口的个数。而window reverse函数则是对应的逆过程。

Swin Transformer算法解读_第3张图片

(3)Window Attention

(3.1)计算公式

需要在原始计算Attention的公式中的QK时加入相对位置编码。

Q,K,V.shape=[numWindows*B, num_heads, window_size*window_size, head_dim]

Window_size*window_size即NLP中token的个数

Head_dim = embedding_dim / num_heads,即NLP中token的词嵌入向量的维度

QKT计算出来的Attention张量的形状为[numWindows*B, num_heads, Q_tokens, K_tokens]

其中,Q_tokens=K_tokens=window_size * window_size

(3.2)相对位置索引

首先说下 绝对位置索引

Token的长度为window_size*window_size,当window_size=2时,每个token用二维的坐标(x, y)表示,即标记window_size中每个点的绝对位置索引。

第一个token的query对所有token的attention如下:

Swin Transformer算法解读_第4张图片

因此:

Swin Transformer算法解读_第5张图片

第i行 表示 第i个token的query对所有的token的key的attention

然后说下 相对位置索引

Swin Transformer算法解读_第6张图片

所以QKT的相对位置索引为:

Swin Transformer算法解读_第7张图片

由于最终我们希望使用一维的位置坐标x+y代替二维的位置坐标(x,y),为了避免(1,2)(2,1)两个坐标转为一维时均为3,我们之后对相对位置索引进行了一些线性变换,使得能通过一维的位置坐标唯一映射到一个二维的位置坐标。整体的变换思路示例如下:

Swin Transformer算法解读_第8张图片

上面计算的是相对位置索引,而不是相对位置偏置参数。真正使用到的可训练参数保存在relative position bias table表里的,这个表的长度等于(2*window_size-1) * (2*window_size-1)。这个长度和相对位置索引的最大值是一致的。relative position bias table是需要训练得到的。

(4)Attention Mask

通过设置合理的mask,让shifted window attention在与window attention相同的窗口个数下,达到等价的计算结果。

首先我们对Shift Window后的每个窗口都给上index,如下图所示:

Swin Transformer算法解读_第9张图片

第一次shift window的时候,H=W=56,以window_size=7划分窗口,则可以划分8*8=64个窗口。Shift_size = window_size // 2 = 3。

假设window_size=2,shift_size=1,则可以得到如下结果:

Swin Transformer算法解读_第10张图片

我们在计算Attention的时候,让具有相同index QK进行计算,而忽略不同index QK计算结果。

Swin Transformer算法解读_第11张图片

(5)merge windows

四、patch merging (down sample)

该模块的作用是做降采样,用于缩小分辨率,调整通道数进而形成层次化的设计,同时也能节省一定运算量。

每次降采样是2倍,因此在行方向和列方向上,间隔2选取元素。

然后拼接在一起作为一整个张量,最后展开。此时通道数维度会变成原先的4倍(因为H,W各缩小2倍),此时再通过一个全连接再调整通道维度为原来的2倍。

下面是一个示意图(输入张量N=1,H=W=8, C=1)

Swin Transformer算法解读_第12张图片

Swin Transformer算法解读_第13张图片

五、Transformer Block核心逻辑图

Swin Transformer算法解读_第14张图片

Swin Transformer算法解读_第15张图片

你可能感兴趣的:(神经网络,transformer,算法,深度学习)