其中Block是最核心的,对attention的计算方法进行了改进,一个Transformer Blocks包含两个部分,一个是基于窗口的注意力计算——W-MSA,另外一个是窗口滑动后重新计算注意力——SW-MSA,它俩串联在一起就是一个block
输入:图像数据(224,224,3)
输出:(3136,96)相当于序列长度是3136个,每个的向量是96维特征
通过卷积得到,Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
3136也就是 (224/4) * (224/4)得到的,也可以根据需求更改卷积参数
输入:特征图(56,56,96)
默认窗口大小为7,所以总共可以分成8*8个窗口
输出:特征图(64,7,7,96)
之前的单位是序列,现在的单位是窗口(共64个窗口)
对得到的窗口,计算各个窗口自己的自注意力得分
qkv三个矩阵放在一起了:(3,64,3,49,32)
3个矩阵,即q,k,v三个矩阵,每个矩阵有64个窗口,heads为3,窗口大小7*7=49(即每个窗口有49个token),每个head特征96/3=32
通过得到的attention计算得到新的特征(64,49,96),总共64个窗口,每个窗口7*7的大小,每个点对应96维向量,window_reverse就是通过reshape操作还原回去(56,56,96)
为什么要shift?原来的window都是算自己内部的,这样就会导致只有内部计算,没有它们之间的关系,容易上模型局限在自己的小领地,可以通过shift操作来改善。
例如:假设我们有8个窗口,分别为1,2,3,4,5,6,7,8
两两一组合并,第一次(1,2),(3,4),(5,6),(7,8)
第二次,滑动窗口,假设为strides=1,为(2,3),(4,5),(6,7),(8,1)
依次滑动
位移就是像素点移动了一下位置 :
窗口移动后,带来了计算量的问题,例如原来4个,现在9个了,计算量怎么解决呢?
首先得到新窗口,并对其做位移操作
在计算时,只需要计算自己窗口的,其他的都都是无关的,比如说对于7,1,我们只取对角线上自己需要的结果,其他部分全部mask掉,让其值为负无穷即可,最后再经过softmax操作,输出结果同样为(56,56,96),计算完特征后需要对图像进行还原,也就是还原平移
下采样操作,但是不同于池化,这个相当于间接的 (对H和W维度进行间隔采样后拼接在一起,得到H/2,W/2,C*4)
一次下采样后(3136->784也就是56*56->28*28),继续走这两个模块,也就是各个stage的流程,最后根据任务来选择合适的head层即可(分类,分割,检测等)