Swin Transformer网络架构、相应改进模块的理解

swin-Transformer

Transformer越来越火,个人感觉基于Transformer来做视觉真是把矩阵用得出神入化!!

Swin Transformer网络架构、相应改进模块的理解_第1张图片

Swin-Transformer相较于VIT改进的方法:

  • SwinT使用类似CNN中层次化构建方法,这样的backbone有助于在此基础上构建检测和分割任务,而VIT中是直接一次性下采样16倍,后面的特征图都是维持这个下采样率不变。
  • 在SwinT中使用Windows Multi-head Self-Attention(WMSA)的概念,在上图中4倍下采样和8倍下采样中,将图片划分成了多个不相交的区域(window),而Multi-head Self-Attention 只在每个独立的window中操作,相对于VIT中直接对全局window进行Multi-head Self-Attention,这样做的目的是为了减少计算量,虽然SwinT提出的WMSA有节约计算量的能力,但是它是牺牲不同window之间的信息传递为代价的,所以作者又针对WMSA的缺点,提出了Shifted Windows Multi-head Self-Attention(SW-MSA),通过这样的方法能够让信息在相邻的窗口中进行信息传递!

SwinT的网络架构图

Swin Transformer网络架构、相应改进模块的理解_第2张图片

  • 首先将图片(H * W * C)输入到Patch Partition模块进行分块,实现方法用四倍下采样的,宽高/4,通道 * 16 ,再通过Linear Embedding层,该层也是通过conv实现的,主要实现的功能降通道(H/4,W/4,16*C)—> (H/4,W/4,C)
  • 然后就是通过四个stage构建不同大小的特征图,除了stage1中先通过Linear Embedding层外,其他三个stage都是通过Patch Merging层来下采样,然后都是堆叠重复的SwinT block,可以从(b)中看到,SwinT block中有两个结构W-MSA和SW-MSA,因为这两个结构都是成对使用的,所以可以看到堆叠的block都是偶数。
  • 最后对于分类网络,后面还会接上一个Layer Norm层,全局池化层以及FC层得到最终的输出。

接下来分别对Patch Merging、W-MSA、SW-MSA以及使用到的相对位置偏执(relative position bias)进行详解,而SwinT block中使用的MLP结构和VIT中结构是一样的

* Patch Merging 详解

Patch Merging跟Yolov5中focus结构差不多,隔一个像素点为一个patch,这样宽高/2,C * 4,然后通过一个Layer Norm层,最后通过一个FC层在Feature Map的深度方向做线性变化(H/2,W/2,C*4)-> (H/2,W/2,C * 2) 。


* W-MSA详解

引入Windows Multi-head Self-Attention模块是为了减少计算量,实现思路:就是将一张图片分成多个window,window很多分patch(像素),每个patch只在该部分的window中做Multi-head Self-Attention。注意: W-MSA中每个window并没有信息的交互。

* SW-MSA详解

作者根据W-MSA中window之间不能进行信息交互做出了改进,提出了SW-MSA。

Swin Transformer网络架构、相应改进模块的理解_第3张图片

如上图所示,左侧为W-MSA在layer L使用,SW-MSA则在L+1层使用,因为从SwinT block中可以看到都是成对使用的,从左右两幅图对比能够发现窗口(Windows)发生了偏移,以这个↘偏移 M/2 个像素,这就解决了不同窗口之间无法进行信息交流的问题!!

window个数有之前4个变成现在的9个了!!!!!!!!!!!!!!!!!

作者采用Efficient batch computation for shifted configuration 这种计算方法,也就是将右图中每个window重新组合成4个window!但是一个问题是不同区域所带的信息不同,如果强制合并在一起的话容易造成信息混乱,作者解决的方式是新区域的像素不是原区域的像素的话,在计算QK后都减去100,这样在softmax后,这部分的像素与其他像素的联系则是0了,**注意:**计算完后还要把数据给挪回到原来的位置。

模型参数配置详解

Swin Transformer网络架构、相应改进模块的理解_第4张图片

  • win.sz 7 * 7表示使用的window大小

  • dim表示feature map的通道深度(或者可以说是token的向量长度)

  • head表示多头注意力模块中head个数

你可能感兴趣的:(读论文,深度学习,transformer,网络,架构)