swin transformer详解

1.解决问题

  • 图像中像素点太多了,如果需要更多的特征就必须构建很长的序列
  •  越长的序列注意力的计算肯定越慢,这就导致了效率问题
  • 能否用窗口和分层的形式来替代长序列的方法呢?这就是它的本质
  • CNN中经常提到感受野,transformer在分层中进行体现

2.整体网络架构

  1.   得到各Pathch特征构建序列;
  2. 分层计算attention(逐步下采样过程)

        其中Block是最核心的,对attention的计算方法进行了改进,一个Transformer Blocks包含两个部分,一个是基于窗口的注意力计算——W-MSA,另外一个是窗口滑动后重新计算注意力——SW-MSA,它俩串联在一起就是一个block

swin transformer详解_第1张图片

 

Patch Embedding

输入:图像数据(224,224,3)

输出:(3136,96)相当于序列长度是3136个,每个的向量是96维特征

通过卷积得到,Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))

3136也就是 (224/4) * (224/4)得到的,也可以根据需求更改卷积参数

window_partition

输入:特征图(56,56,96)

默认窗口大小为7,所以总共可以分成8*8个窗口

输出:特征图(64,7,7,96)

之前的单位是序列,现在的单位是窗口(共64个窗口)         

W-MSA(Window Multi-head Self Attention)

        对得到的窗口,计算各个窗口自己的自注意力得分

        qkv三个矩阵放在一起了:(3,64,3,49,32)

        3个矩阵,即q,k,v三个矩阵,每个矩阵有64个窗口,heads为3,窗口大小7*7=49(即每个窗口有49个token),每个head特征96/3=32

window_reverse

        通过得到的attention计算得到新的特征(64,49,96),总共64个窗口,每个窗口7*7的大小,每个点对应96维向量,window_reverse就是通过reshape操作还原回去(56,56,96)

SW-MSA(Shifted Window) 

        为什么要shift?原来的window都是算自己内部的,这样就会导致只有内部计算,没有它们之间的关系,容易上模型局限在自己的小领地,可以通过shift操作来改善。

        例如:假设我们有8个窗口,分别为1,2,3,4,5,6,7,8

        两两一组合并,第一次(1,2),(3,4),(5,6),(7,8)

        第二次,滑动窗口,假设为strides=1,为(2,3),(4,5),(6,7),(8,1)

        依次滑动

位移中的细节

        位移就是像素点移动了一下位置 :

swin transformer详解_第2张图片

        窗口移动后,带来了计算量的问题,例如原来4个,现在9个了,计算量怎么解决呢? 

        swin transformer详解_第3张图片 

 swin transformer详解_第4张图片

首先得到新窗口,并对其做位移操作

swin transformer详解_第5张图片 

        在计算时,只需要计算自己窗口的,其他的都都是无关的,比如说对于7,1,我们只取对角线上自己需要的结果,其他部分全部mask掉,让其值为负无穷即可,最后再经过softmax操作,输出结果同样为(56,56,96),计算完特征后需要对图像进行还原,也就是还原平移

swin transformer详解_第6张图片 PatchMerging

下采样操作,但是不同于池化,这个相当于间接的 (对H和W维度进行间隔采样后拼接在一起,得到H/2,W/2,C*4)

 

swin transformer详解_第7张图片 

分层计算 

        一次下采样后(3136->784也就是56*56->28*28),继续走这两个模块,也就是各个stage的流程,最后根据任务来选择合适的head层即可(分类,分割,检测等)

swin transformer详解_第8张图片

 

 

        

 

        

你可能感兴趣的:(transformer,深度学习,人工智能,python,目标检测,transformer)