Swin Transformer中的数据形状梳理

Swin Transformer中的数据形状梳理

  • 关键零件内部数据形状
    • PatchEmbed层
    • Swin-transformer层
    • PatchMerging层
  • 整体结构

关键零件内部数据形状

PatchEmbed层

原始输入x形状为(b, 3, 224, 224)
PatchEmbed层完成embedding的过程
第一步还是对图片进行Patch分割与embedding编码
patch_size=4, embed_dim=96,使用96个卷积核44的卷积层以步长为4卷积进行
得到输出x形状为(b, 56
56, 96) -> (b, 3136, 96)

输入形状为 (b, 3136, 96)

Swin-transformer层

输入形状为 (b, 3136, 96)
Swin-tran层
先变为(b, 56, 56, 96)
如果是SW-MSA需进行位移操作,操作完成后仍是(b, 56, 56, 96)
经过window_partition层
经过窗口分割窗口尺寸为7,因此最终有64个窗口,因而分割后变成为(b64, 7, 7, 96),之前5656的矩阵中每个点96个编码值对应原图中一个44的一个patch,经过窗口分割,此时77的矩阵中每个点仍有96个编码值对应原图中仍是一个44的patch,即不改变感受野,且自注意力机制发生在这77=49个patchs之间。
输出为(b64, 7, 7, 96)
先变为(b
64, 49, 96)
正式开始窗口自注意力操作层
对于形状是(b64, 49, 96)的输入数据先通过mlp生成QKV矩阵(b64,49,963),则经过拆分后Q,K,V为(b64,49,96),由于有3个头,Q,K,V为(b64, 3, 49, 32),即Q,K,V的原始长度与输入embedding编码的长度是一致的,然后对各个头进行均分得到多头QKV。QK计算注意力分数形状为(b64, 3, 49, 49),相对位置编码得到注意力分数偏置B形状为(b64, 3, 49, 49),相加后即得到携带相对位置信息的注意力分数,经过softmax后(# 如果是SW-MSA需进行带掩码的softma)与V相乘即可,最终形状为(b64, 3, 49, 32),最终还原回(b64, 49, 96),再经过一段mlp
输出为(b
64, 49, 96)
先变为(b*64, 7, 7, 96)
经过window_reverse层
输出为(b, 56, 56 , 96)
如果是SW-MSA需进行反位移操作,操作完成后仍是(b, 56, 56, 96)
变为(b, 3136, 96)
经过FFN层
输出为(b, 3136 , 96)
输入形状为 (b, 3136, 96)

PatchMerging层

输入形状为(b, 3136 , 96)
经过PatchMerging层
先变为(b, 56, 56 , 96)
对patchs进行混合,选择偶行偶列,偶行奇列,奇行偶列,奇行奇列,生成四个形如(b, 28, 28 , 96),贴合为(b, 28, 28 , 384)的新向量,变为(b, 2828 , 384)最终经过mlp层输出为(b, 784 , 192)
经过此层28
28的矩阵每个点192个向量值,代表原图中4个4*4的patch,即感受野成倍扩大
输出为(b, 784, 192)

整体结构

输入(b, 3, 224, 224)
Embedding编码
输出(b, 3136, 96)
此时3136行,每行的感受野为原图中的一个patch
W-MSA组成的Swin-tran与SW-MSA组成的Swin-tran组成一个Basiclayer
经过一个Basiclayer
输出(b, 3136, 96)
经过一个PatchMerging
输出为(b, 784, 192)
此时784行,每行的感受野为原图中的四个patch
经过一个Basiclayer
输出为(b, 784, 192)
经过一个PatchMerging
输出为(b, 196, 384)
此时196行,每行的感受野为原图中的十六个patch
经过三个Basiclayer
输出为(b, 196, 384)
经过一个PatchMerging
输出为(b, 49, 768)
此时49行,每行的感受野为原图中的64个patch
经过一个Basiclayer
输出为(b, 49, 768)
经过一个全局池化层,768视为图层数则
输出为(b, 768)
经过全连接层获得最终输出
输出为(b, classes_num)

你可能感兴趣的:(Vision,Transformer,transformer,深度学习,人工智能)