Swin-Transformer 从数据尺度变换角度解析

        最近在看Swin-Transformer时,发现在网络中存在许多的数据尺度变换。本博文主要是从图像数据的输入逐步分析一张RGB图像在Swim-Transformer中是如何进行尺度变换的。至于Swin-Transformer网络的详细内容,本博文不会展开说明,可参考Swin-T .

        下面先给出Swin-Transformer的网络架构图:

Swin-Transformer 从数据尺度变换角度解析_第1张图片

        这里以Swin-Transformer-Tiny版本为例。假设,输入图像的尺寸为224x224x3(H,W,C)。

        首先,将图片输入到Patch Partition模块中进行分块,即每4x4相邻的像素为一个Patch,然后在通道方向展平(flatten)。即每个patch有4x4=16个像素,然后每个像素有R、G、B三个值所以展平后是16x3=48,所以通过Patch Partition后图像shape由 [224,224, 3]变成了 [56, 56, 48]。然后在通过Linear Embeding层对每个像素的channel数据做线性变换,由48变成96,即图像shape再由 [56, 56, 48]变成了 [56,56, 96]。其实在源码中Patch Partition和Linear Embeding就是直接通过一个卷积操作实现的。

        其次,进入第一个Swin-Transformer-block,注意,这里的block有两种结构,区别在于一个使用了W-MSA结构,一个使用了SW-MSA结构,并且这两种不同的block是成对组成的,所以,你会发现堆叠的Swin Transformer Block的次数都是偶数。在Swin-Transformer-block中,我们要对输入的feature-map按照Window_size的大小进行一个一个的划分。例如,经过上一层,我们输入的feature-map为[56x56x96],这里设置Window_size的大小为7x7,那么会被分为8x8总共64个shape为[7x7x96]的Swin窗口,这一步骤在源码中称为window_partition,此步骤的图像shape变换为[56x56x96]-->[64x7x7x96]--->[64x49x96], 64是窗口个数,不是数据维度,所以将3维数据转为2维数据,可以使用transformer_attention。这里还需要注意,我们在每次进入block_stage时,都会创建一个feature_mask,通过它,我们可以实现SW-MSA,具体较复杂,这里不展开说。做完attention操作后(这里还有相对位置偏执也不展开说了)再将分开的Swin窗口拼接成一个完整的feature_map,这一操作源码中称作window_reverse,此步骤的图像shape变换为[64x49x96]-->[64x7x7x96]-->[56x56x96]。在Swin-Transformer-block中还有一个MLP结构,此结构的shape变化较为简单[56x56x96]-->[56x56x384]-->[56x56x96]。所以,Swin-Transformer-block的输入尺寸和输出尺寸是一致的。总结下来,shape的变化过程[1,56,56,96](B,H,W,C)-->[64,7,7,96]-->[64,49,96]-->[64,7,7,96]-->[1,56,56,96]-->[1,56,56,384]-->[1,56,56,96]。上述过程为第一个stage中第一个Swin-Transformer-block的尺度变换,stage中其他Swin-Transformer-block的尺度变化可依次类推。在经过stage1以后shape为[56x56x96]。

        接着,通过Patch Merging,具体操作不展开说明,shape变化从[56x56x96]-->[28x28x192],其他的Patch Merging shape变换也可以类推。

        最终,在经过4个stage后,原始尺寸从[224x224x3]-->[7x7x768],在经过一个全局平均池化操作shape[7x7x768]-->[1x1x768],最后,经过一个线性分类器进行分类预测输出[768]-->[num_classes]

        这就是Swin-Transfomer网络大致的数据尺度变换。

        后续有时间再完整的补充下。

 
  
 
  
 
  


 

 

你可能感兴趣的:(transformer,深度学习,人工智能,Swin-Trans)