[Transformer] 9. CrossFormer: A versatile vision transformer based on cross-scale attention

Contents

    • I. Background
    • II. Motivation
    • III. Method
      • 3.1 Cross-scale Embedding Layer (CEL)
      • 3.2 Cross-former Block
        • 3.2.1 Long Short Distance Attention (LSDA)
        • 3.2.2 Dynamic Position Bias (DPB)
      • 3.3 CrossFormer Variants
    • IV. Results
    • V. Code


Paper: CrossFormer: A versatile vision transformer based on cross-scale attention

Code: https://github.com/cheerss/CrossFormer

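I. Background

Transformers have already been applied successfully to computer vision. Most of these methods split the input image into patches, encode the patches as a sequence of features, and use attention inside the transformer to model the relations among the encoded features. However, vanilla self-attention is far more expensive here than in NLP, because its cost grows quadratically with the number of tokens and an image yields many patches, so several works have modified self-attention to make it cheaper.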

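II. Motivation

In vision tasks, objects in the same image can differ greatly in scale, so relating two objects of very different sizes requires cross-scale attention. Few existing methods build attention across features of different scales well, for two reasons:

  • First, the input embeddings of each layer all have the same scale, so no cross-scale features are available.
  • Second, some methods sacrifice the small-scale features of the embeddings for the sake of efficiency.

Based on this, the authors build the CrossFormer architecture.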

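III. Method

The overall architecture, shown in Figure 1, is a pyramid with four stages. Each stage consists of:

  • one cross-scale embedding layer (CEL)

    The CEL takes the output of the previous stage as input (the first CEL takes the raw image) and produces cross-scale embeddings. In every stage except stage 1, the CEL reduces the number of embeddings to 1/4 of its input and doubles the embedding dimension.

  • several CrossFormer blocks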

[Figure 1: overall architecture of CrossFormer]
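To see what "1/4 of the embeddings, twice the dimension" means in numbers, the plain-Python sketch below tracks the embedding grid through the four stages. The 224x224 input, the stride-4 first CEL and embed_dim=64 are taken from the configuration used in the code section at the end of this post; the helper name stage_shapes is hypothetical.

```python
def stage_shapes(img_size=224, first_stride=4, embed_dim=64, num_stages=4):
    """Return (side length, channels) of the embeddings entering each stage.

    Assumes the first CEL uses stride 4 and every later CEL halves H and W
    (so the embedding count drops to 1/4) while doubling the channel dimension.
    """
    side, dim = img_size // first_stride, embed_dim
    shapes = []
    for _ in range(num_stages):
        shapes.append((side, dim))
        side, dim = side // 2, dim * 2
    return shapes

for i, (side, dim) in enumerate(stage_shapes()):
    print(f"stage {i}: {side}x{side} = {side * side} embeddings, dim {dim}")
```

The resulting grids (56x56/64, 28x28/128, 14x14/256, 7x7/512) agree with the stage list and the printed model in the code section.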
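1. Cross-scale Embedding Layer (CEL)

The authors use a pyramid-structured Transformer: the model is divided into several stages, and a CEL is placed before each stage. The CEL takes the previous stage's output as input and samples patches with kernels of several different sizes. These embeddings are linearly projected and concatenated, so the result can be treated as single-scale patches.

2. Long Short Distance Attention (LSDA)

The authors also propose LSDA as a replacement for conventional self-attention. LSDA sacrifices neither small-scale nor large-scale features, so it enables cross-scale interaction. LSDA splits self-attention into two parts:

  • short-distance attention (SDA): builds attention between each embedding and the embeddings near it
  • long-distance attention (LDA): builds attention between each embedding and embeddings far away from it

3. The authors introduce a learnable module, dynamic position bias (DPB), for position representation. Its input is the distance between two embeddings and its output is the position bias between them. The earlier relative position bias (RPB) is efficient but only applicable when all input images have the same size, so it is not well suited to object detection.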

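3.1 Cross-scale Embedding Layer (CEL)

The CEL generates the input embeddings of every stage. As shown in Figure 2, before stage 1 the raw image is the input of the first CEL: four kernels of different sizes sample patches at the same positions (the patches share the same center), and the patches are then projected and concatenated to form the embeddings.

One remaining question is how to choose the projected dimension for each patch size. Since large kernels incur high computational cost, the authors give large kernels low-dimensional outputs and small kernels high-dimensional outputs.

In later stages, the CEL takes the previous stage's output as input, samples patches with kernels of different scales, linearly projects the resulting embeddings and concatenates them; the result can again be treated as single-scale patches.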
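The dimension rule can be checked with a few lines of plain Python. The sketch mirrors the channel-allocation loop of PatchEmbed in the code section; the 128-channel, [2, 4]-kernel case is an assumption about later stages, inferred from the PatchMerging layers in the printed model (two Conv2d(64, 64) reductions feeding a 128-d stage). Both function names are hypothetical.

```python
def cel_dims(embed_dim, kernel_sizes):
    """Output channels per kernel, following the loop in PatchEmbed.

    Each larger kernel gets half the channels of the previous one; the last
    (largest) kernel gets the same share as the one before it, so the shares
    add up to embed_dim when embed_dim is divisible by 2 ** (n - 1).
    """
    n = len(kernel_sizes)
    return [embed_dim // 2 ** i if i == n - 1 else embed_dim // 2 ** (i + 1)
            for i in range(n)]

def cel_paddings(kernel_sizes):
    """Padding per kernel when all kernels share the smallest kernel's stride."""
    stride = kernel_sizes[0]
    return [(k - stride) // 2 for k in kernel_sizes]

print(cel_dims(64, [4, 8, 16, 32]), cel_paddings([4, 8, 16, 32]))
print(cel_dims(128, [2, 4]), cel_paddings([2, 4]))
```

For the first CEL this gives 32/16/8/8 channels for kernels 4/8/16/32, so the 32x32 kernel, the most expensive per output channel, gets the fewest channels.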

[Figure 2: cross-scale embedding layer]

3.2 Cross-former Block

Each Cross-former block contains either an SDA or an LDA module, followed by an MLP; SDA and LDA never appear in the same block.
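To make "SDA or LDA, then an MLP" concrete, here is a minimal NumPy sketch of one block operating on embeddings that have already been grouped. Assumptions: a pre-norm residual layout (norm1 -> attention, norm2 -> MLP), inferred from the norm1/attn/norm2/mlp submodules in the printed model; a single head; no position bias, output projection, dropout or DropPath. All function names are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def group_attention(x, wq, wk, wv):
    # x: (num_groups, tokens_per_group, C); attention never crosses group borders
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    scores -= scores.max(-1, keepdims=True)          # numerically stable softmax
    attn = np.exp(scores)
    attn /= attn.sum(-1, keepdims=True)
    return attn @ v

def crossformer_block(x, params):
    # x: embeddings already grouped by SDA or LDA, shape (num_groups, G*G, C)
    x = x + group_attention(layer_norm(x), *params["attn"])
    h = gelu(layer_norm(x) @ params["fc1"])           # MLP acts on each embedding alone
    return x + h @ params["fc2"]

rng = np.random.default_rng(0)
C, G, num_groups = 8, 3, 4
params = {
    "attn": [rng.normal(size=(C, C)) * 0.1 for _ in range(3)],
    "fc1": rng.normal(size=(C, 4 * C)) * 0.1,         # mlp_ratio = 4
    "fc2": rng.normal(size=(4 * C, C)) * 0.1,
}
x = rng.normal(size=(num_groups, G * G, C))
y = crossformer_block(x, params)
print(y.shape)
```

Because attention only mixes embeddings inside a group and the MLP acts on each embedding independently, whether a block behaves as SDA or LDA is decided entirely by how the embeddings are grouped before attention.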


3.2.1 Long Short Distance Attention(LSDA)

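The authors split self-attention into two parts:

  • short-distance attention (SDA): builds attention between each embedding and the embeddings near it
  • long-distance attention (LDA): builds attention between each embedding and embeddings far away from it

1. For SDA, every $G \times G$ block of adjacent embeddings is grouped together; Figure 3a shows the case $G = 3$.

2. For LDA, with an $S \times S$ input, embeddings are sampled at a fixed interval $I$. In Figure 3b, $I = 3$: all embeddings marked in red belong to one group and the yellow ones to another. Each group has width and height $G = S / I$, here $G = 3$.

3. After grouping, both SDA and LDA apply standard self-attention within each group. Each of the $S^2$ embeddings then attends to $G^2$ embeddings instead of all $S^2$, so the complexity drops from $O(S^4)$ to $O(S^2 G^2)$.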
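The two grouping schemes can be reproduced on a toy 9x9 grid of embedding indices with S = 9, G = 3 and I = 3, the setting of Figure 3. This NumPy sketch applies the same reshape-and-transpose idea as the grouping code in the code section to a 2-D index grid; the attention_pairs helper is hypothetical and simply counts query-key pairs.

```python
import numpy as np

S, G = 9, 3                             # 9x9 embeddings, G x G per group (Figure 3 setting)
I = S // G                              # LDA sampling interval, here 3
grid = np.arange(S * S).reshape(S, S)   # entry = flat index of an embedding

# SDA: each group is a block of G x G adjacent embeddings
sda = grid.reshape(S // G, G, S // G, G).transpose(0, 2, 1, 3).reshape(-1, G * G)

# LDA: each group takes every I-th embedding along rows and columns
lda = grid.reshape(G, I, G, I).transpose(1, 3, 0, 2).reshape(-1, G * G)

print(sda[0].tolist())  # [0, 1, 2, 9, 10, 11, 18, 19, 20]
print(lda[0].tolist())  # [0, 3, 6, 27, 30, 33, 54, 57, 60]

def attention_pairs(S, G=None):
    """Number of query-key pairs scored by attention, which drives its cost."""
    n = S * S
    return n * n if G is None else n * G * G   # S^4 vs S^2 * G^2

print(attention_pairs(56), attention_pairs(56, 7))  # 9834496 153664
```

For a 56x56 stage with G = 7, grouping cuts the pair count by a factor of 64, which is (S/G)^2.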

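In Figure 3b, the authors also draw the patches that make up two embeddings. Their small-scale patches are not adjacent, so without the large-scale patches it would be hard to judge how the two are related: two embeddings built only from small patches are difficult to connect. Their large-scale patches, by contrast, are adjacent and provide enough context to link the two embeddings. Cross-scale attention therefore handles well the dependencies that are dominated mainly by large-scale patches.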

3.2.2 Dynamic position bias (DPB)

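Relative position bias (RPB) is commonly used to represent the relative positions of embeddings; it is a bias added to the attention map.

Although efficient, RPB only works when all input images have the same size, so it is not well suited to object detection.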
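For reference, in Swin-style transformers attention with a relative position bias is usually written as follows. This is the standard form, stated as an assumption rather than copied from the CrossFormer paper; $d$ is the per-head dimension and $\hat{B}$ is a learnable table indexed by the offset between embeddings $i$ and $j$.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{Softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}} + B\right) V,
\qquad B_{i,j} = \hat{B}\left[\Delta x_{i,j},\, \Delta y_{i,j}\right]
```

Within a $G \times G$ group both offsets lie in $[-(G-1), G-1]$, so $\hat{B}$ holds $(2G-1)^2$ entries per head, 169 for $G = 7$. Because the table's shape is fixed by the range of offsets it was built for, it cannot serve offsets outside that range; DPB replaces the table lookup with a small network evaluated on $(\Delta x_{i,j}, \Delta y_{i,j})$.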

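The authors therefore propose DPB, whose structure is shown in Figure 3c. Its input is two-dimensional, $(\Delta x_{i,j}, \Delta y_{i,j})$, the coordinate offset between the $i$-th and $j$-th embeddings. It consists of three fully connected layers with layer normalization and ReLU activations, and the intermediate layers have dimension $D/4$, where $D$ is the embedding dimension.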
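A minimal NumPy sketch of DPB, assuming the layout visible in the printed model: pos_proj, then three LayerNorm -> ReLU -> Linear stages, ending in one output per attention head. In that printout the hidden width is 4 for D = 64 and 8 for D = 128, i.e. D/16 rather than D/4, so the width is a parameter here. Weights are random, every pair is evaluated directly (no caching of repeated offsets), and the function names are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def init_dpb(hidden, num_heads, rng):
    """Random weights: projection 2 -> hidden, two hidden layers, output hidden -> num_heads."""
    shapes = [(2, hidden), (hidden, hidden), (hidden, hidden), (hidden, num_heads)]
    return [(rng.normal(size=s) * 0.1, np.zeros(s[1])) for s in shapes]

def dpb(params, G):
    """Position bias of shape (num_heads, G*G, G*G) for one G x G group."""
    ys, xs = np.meshgrid(np.arange(G), np.arange(G), indexing="ij")
    coords = np.stack([xs.ravel(), ys.ravel()], axis=1)             # (G*G, 2) as (x, y)
    rel = (coords[:, None, :] - coords[None, :, :]).reshape(-1, 2)  # (dx_ij, dy_ij) per pair
    w, b = params[0]
    h = rel.astype(float) @ w + b                    # input projection (pos_proj)
    for w, b in params[1:]:
        h = np.maximum(layer_norm(h), 0.0) @ w + b   # LayerNorm -> ReLU -> Linear
    return h.reshape(G * G, G * G, -1).transpose(2, 0, 1)

rng = np.random.default_rng(0)
params = init_dpb(hidden=4, num_heads=2, rng=rng)    # D = 64 in the printout: width 4, 2 heads
for G in (3, 7):                                     # same weights, different group sizes
    print(G, dpb(params, G).shape)
```

The same parameters produce biases for any group size, which is the property RPB lacks, and each bias depends only on the offset between two positions.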

3.3 CrossFormer Variants

Table 1 lists the CrossFormer variants T/S/B/L, i.e., tiny, small, base, and large.

IV. Results

[Figures: experimental results]

V. Code

After downloading the code, you can build and run the model as follows to see how CrossFormer is implemented.

import torch
from models.crossformer import CrossFormer

model = CrossFormer(img_size=224,
                    patch_size=[4, 8, 16, 32],
                    in_chans=3,
                    num_classes=1000,
                    embed_dim=64,
                    depths=[ 1, 1, 8, 6 ],
                    num_heads=[ 2, 4, 8, 16 ],
                    group_size=[ 7, 7, 7, 7 ],
                    mlp_ratio= 4,
                    qkv_bias=True,
                    qk_scale=None,
                    drop_rate=0.0,
                    drop_path_rate=0.1,
                    ape=False,
                    patch_norm=True,
                    use_checkpoint=False,
                    merge_size=[[2, 4], [2,4], [2, 4]],)

model.eval()
x = torch.randn(1, 3, 224, 224)  # dummy batch: one 224x224 RGB image
with torch.no_grad():
    output = model(x)  # class logits, shape [1, 1000]

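1. Patch embedding implementation:

The input is the raw image. It is convolved with kernels of different sizes, the outputs are concatenated along the channel dimension, and the result is fed to the CrossFormer blocks.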

import torch
import torch.nn as nn
from timm.models.layers import to_2tuple

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        # patch_size is a list of kernel sizes; the smallest, patch_size[0], sets the stride
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[0]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.projs = nn.ModuleList()
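        for i, ps in enumerate(patch_size):
            # larger kernels get fewer channels; the last one repeats the previous share so the total is embed_dim
            if i == len(patch_size) - 1:
                dim = embed_dim // 2 ** i
            else:
                dim = embed_dim // 2 ** (i + 1)
            stride = patch_size[0]
            # shared stride; this padding centers every kernel on the same location
            padding = (ps - patch_size[0]) // 2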
            self.projs.append(nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding))
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        xs = []
        for i in range(len(self.projs)):
            tx = self.projs[i](x).flatten(2).transpose(1, 2) #[1,32,56,56],[1,16,56,56],[1,8,56,56],[1,8,56,56]
            xs.append(tx)  # B Ph*Pw C #xs[0]=[1, 3136, 32], xs[1]=[1, 3136, 16], xs[2]=[1, 3136, 8], xs[3]=[1, 3136, 8]
        x = torch.cat(xs, dim=2) # [1, 3136, 64]
        if self.norm is not None:
            x = self.norm(x)
        return x
PatchEmbed(
  (projs): ModuleList(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(4, 4))
    (1): Conv2d(3, 16, kernel_size=(8, 8), stride=(4, 4), padding=(2, 2))
    (2): Conv2d(3, 8, kernel_size=(16, 16), stride=(4, 4), padding=(6, 6))
    (3): Conv2d(3, 8, kernel_size=(32, 32), stride=(4, 4), padding=(14, 14))
  )
  (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)
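Why the four projections can be concatenated: with padding (k - stride) / 2 they all produce a 56x56 map and are centered on the same input location. The plain-Python check below uses the standard Conv2d output-size formula and the kernel/stride/padding values from the printouts in this section; the helper names are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Conv2d output length along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def kernel_center(n, kernel, stride, padding):
    """Input coordinate at the center of the n-th output position."""
    return n * stride - padding + (kernel - 1) / 2

# (kernel, stride, padding) of the four PatchEmbed projections printed above
patch_embed = [(4, 4, 0), (8, 4, 2), (16, 4, 6), (32, 4, 14)]
# the two reductions of each PatchMerging layer printed later in this section
merge = [(2, 2, 0), (4, 2, 1)]

print([conv_out(224, *p) for p in patch_embed])      # [56, 56, 56, 56]
print([conv_out(56, *p) for p in merge])             # [28, 28]
print({kernel_center(10, *p) for p in patch_embed})  # {41.5}
```

The center of output position n is n * stride + (stride - 1) / 2 whatever the kernel size, so the per-kernel maps line up element by element.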

2. CrossFormer Stage

lsda_flag=0 # this block uses SDA
lsda_flag=1 # this block uses LDA

stage 0: input 56x56, 64-d, 1 SDA block
stage 1: input 28x28, 128-d, 1 SDA block
stage 2: input 14x14, 256-d, 4 SDA and 4 LDA blocks, alternating
stage 3: input 7x7, 512-d, 6 SDA blocks
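The stage list above follows from the alternating rule `lsda_flag = 0 if (i % 2 == 0) else 1` in the Stage class below, plus one extra rule for the last stage. The sketch reproduces the list; the fallback to SDA when the resolution does not exceed the group size is an assumption, chosen because it matches the lsda_flag values in the printed model (all 0 in stage 3). block_schedule is a hypothetical helper.

```python
def block_schedule(depths, resolutions, group_size=7):
    """Per-stage list of block types, alternating SDA/LDA as in Stage.__init__."""
    schedule = []
    for depth, res in zip(depths, resolutions):
        flags = [0 if i % 2 == 0 else 1 for i in range(depth)]
        if res <= group_size:
            # assumption: the whole feature map fits in one group, so SDA and LDA
            # coincide and every block is treated as SDA (lsda_flag=0)
            flags = [0] * depth
        schedule.append(["SDA" if f == 0 else "LDA" for f in flags])
    return schedule

for i, blocks in enumerate(block_schedule([1, 1, 8, 6], [56, 28, 14, 7])):
    print(f"stage {i}:", blocks)
```

With a 7x7 input and G = 7, both groupings would put all 49 embeddings into a single group, so alternating would make no difference in stage 3.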

Implementation of the SDA/LDA grouping:

G = self.group_size          # 7
if self.lsda_flag == 0: # 0 for SDA
    x = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)  # [1, 8, 7, 8, 7, 64] -> [1, 8, 8, 7, 7, 64]
else: # 1 for LDA
    x = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)  # [1, 7, 2, 7, 2, 256] -> [1, 2, 2, 7, 7, 256]
x = x.reshape(B * H * W // G**2, G**2, C) # [64, 49, 64] for SDA in layer1, [4, 49, 256] for LDA in layer2

# multi-head self-attention
x = self.attn(x, mask=self.attn_mask)     # nW*B, G*G, C -> [64, 49, 64]

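Stage 0 as an SDA example:

The input is 56x56 and G = 7. Each row and each column is split into 8 runs of 7 adjacent embeddings, giving 8 x 8 = 64 groups of 7 x 7 = 49 embeddings each (64 channels), which matches the [64, 49, 64] shape above. Attention is then computed within each group.

Stage 2 as an LDA example:

The input is 14x14 and G = 7, so the sampling interval is I = 14 / 7 = 2: every other row and every other column is taken into the same group. This gives 2 x 2 = 4 groups of 7 x 7 = 49 embeddings each, matching the [4, 49, 256] shape above, and attention is computed within each group.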
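The shapes in the code comments above ([64, 49, 64] for SDA in stage 0, [4, 49, 256] for LDA in stage 2) can be checked with a NumPy port of the grouping snippet, where transpose plays the role of permute. The ungroup function, which puts embeddings back in place after attention, is not part of the excerpt and is written here as an assumption; both function names are hypothetical.

```python
import numpy as np

def group(x, G, lsda_flag):
    """(B, H, W, C) -> (B*H*W/G^2, G*G, C), following the reshape/permute above."""
    B, H, W, C = x.shape
    if lsda_flag == 0:  # SDA: windows of G x G adjacent embeddings
        x = x.reshape(B, H // G, G, W // G, G, C).transpose(0, 1, 3, 2, 4, 5)
    else:               # LDA: embeddings sampled with interval I = H // G
        x = x.reshape(B, G, H // G, G, W // G, C).transpose(0, 2, 4, 1, 3, 5)
    return x.reshape(B * H * W // G ** 2, G ** 2, C)

def ungroup(x, B, H, W, G, lsda_flag):
    """Inverse of group(): restore the (B, H, W, C) layout after attention."""
    C = x.shape[-1]
    x = x.reshape(B, H // G, W // G, G, G, C)
    if lsda_flag == 0:
        x = x.transpose(0, 1, 3, 2, 4, 5)
    else:
        x = x.transpose(0, 3, 1, 4, 2, 5)
    return x.reshape(B, H, W, C)

rng = np.random.default_rng(0)
x0 = rng.normal(size=(1, 56, 56, 64))    # stage 0, SDA
x2 = rng.normal(size=(1, 14, 14, 256))   # stage 2, LDA
print(group(x0, 7, 0).shape)  # (64, 49, 64)
print(group(x2, 7, 1).shape)  # (4, 49, 256)
```

In the LDA case, the first group holds the embeddings at rows and columns 0, 2, ..., 12, i.e. one embedding out of every 2x2 neighborhood.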

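import torch.nn as nn
import torch.utils.checkpoint as checkpoint
# CrossFormerBlock is defined elsewhere in models/crossformer.py

class Stage(nn.Module):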
    """ CrossFormer blocks for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        group_size (int): variable G in the paper, one group has GxG embeddings
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, group_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 patch_size_end=[4], num_patch_size=None):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList()
        for i in range(depth):
            lsda_flag = 0 if (i % 2 == 0) else 1
            self.blocks.append(CrossFormerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, group_size=group_size,
                                 lsda_flag=lsda_flag,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 num_patch_size=num_patch_size))

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer, 
                                         patch_size=patch_size_end, num_input_patch_size=num_patch_size)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

Printed structure of the CrossFormer stages:

ModuleList(
  (0): Stage(
    dim=64, input_resolution=(56, 56), depth=1
    (blocks): ModuleList(
      (0): CrossFormerBlock(
        dim=64, input_resolution=(56, 56), num_heads=2, group_size=7, lsda_flag=0, mlp_ratio=4
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=64, group_size=(7, 7), num_heads=2
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=4, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=4, out_features=4, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=4, out_features=4, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=4, out_features=2, bias=True)
            )
          )
          (qkv): Linear(in_features=64, out_features=192, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=64, out_features=64, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=64, out_features=256, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=256, out_features=64, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (downsample): PatchMerging(
      input_resolution=(56, 56), dim=64
      (reductions): ModuleList(
        (0): Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2))
        (1): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      )
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
  )
  (1): Stage(
    dim=128, input_resolution=(28, 28), depth=1
    (blocks): ModuleList(
      (0): CrossFormerBlock(
        dim=128, input_resolution=(28, 28), num_heads=4, group_size=7, lsda_flag=0, mlp_ratio=4
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=128, group_size=(7, 7), num_heads=4
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=8, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=8, out_features=8, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=8, out_features=8, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=8, out_features=4, bias=True)
            )
          )
          (qkv): Linear(in_features=128, out_features=384, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=128, out_features=128, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=128, out_features=512, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=512, out_features=128, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (downsample): PatchMerging(
      input_resolution=(28, 28), dim=128
      (reductions): ModuleList(
        (0): Conv2d(128, 128, kernel_size=(2, 2), stride=(2, 2))
        (1): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
  )
  (2): Stage(
    dim=256, input_resolution=(14, 14), depth=8
    (blocks): ModuleList(
      (0): CrossFormerBlock(
        dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=0, mlp_ratio=4
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=256, group_size=(7, 7), num_heads=8
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=16, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=8, bias=True)
            )
          )
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): CrossFormerBlock(
        dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=1, mlp_ratio=4
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=256, group_size=(7, 7), num_heads=8
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=16, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=8, bias=True)
            )
          )
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (2): CrossFormerBlock(
        dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=0, mlp_ratio=4
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=256, group_size=(7, 7), num_heads=8
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=16, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=8, bias=True)
            )
          )
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (3): CrossFormerBlock(
        dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=1, mlp_ratio=4
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=256, group_size=(7, 7), num_heads=8
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=16, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=8, bias=True)
            )
          )
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (4): CrossFormerBlock(
        dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=0, mlp_ratio=4
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=256, group_size=(7, 7), num_heads=8
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=16, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=8, bias=True)
            )
          )
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (5): CrossFormerBlock(
        dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=1, mlp_ratio=4
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=256, group_size=(7, 7), num_heads=8
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=16, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=8, bias=True)
            )
          )
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (6): CrossFormerBlock(
        dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=0, mlp_ratio=4
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=256, group_size=(7, 7), num_heads=8
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=16, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=8, bias=True)
            )
          )
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (7): CrossFormerBlock(
        dim=256, input_resolution=(14, 14), num_heads=8, group_size=7, lsda_flag=1, mlp_ratio=4
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=256, group_size=(7, 7), num_heads=8
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=16, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=16, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=16, out_features=8, bias=True)
            )
          )
          (qkv): Linear(in_features=256, out_features=768, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=256, out_features=256, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=256, out_features=1024, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=1024, out_features=256, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (downsample): PatchMerging(
      input_resolution=(14, 14), dim=256
      (reductions): ModuleList(
        (0): Conv2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
        (1): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      )
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
  )
  (3): Stage(
    dim=512, input_resolution=(7, 7), depth=6
    (blocks): ModuleList(
      (0): CrossFormerBlock(
        dim=512, input_resolution=(7, 7), num_heads=16, group_size=7, lsda_flag=0, mlp_ratio=4
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=512, group_size=(7, 7), num_heads=16
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=32, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=32, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=32, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=16, bias=True)
            )
          )
          (qkv): Linear(in_features=512, out_features=1536, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): CrossFormerBlock(
        dim=512, input_resolution=(7, 7), num_heads=16, group_size=7, lsda_flag=0, mlp_ratio=4
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=512, group_size=(7, 7), num_heads=16
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=32, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=32, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=32, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=16, bias=True)
            )
          )
          (qkv): Linear(in_features=512, out_features=1536, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (2): CrossFormerBlock(
        dim=512, input_resolution=(7, 7), num_heads=16, group_size=7, lsda_flag=0, mlp_ratio=4
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=512, group_size=(7, 7), num_heads=16
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=32, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=32, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=32, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=16, bias=True)
            )
          )
          (qkv): Linear(in_features=512, out_features=1536, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (3): CrossFormerBlock(
        dim=512, input_resolution=(7, 7), num_heads=16, group_size=7, lsda_flag=0, mlp_ratio=4
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=512, group_size=(7, 7), num_heads=16
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=32, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=32, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=32, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=16, bias=True)
            )
          )
          (qkv): Linear(in_features=512, out_features=1536, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (4): CrossFormerBlock(
        dim=512, input_resolution=(7, 7), num_heads=16, group_size=7, lsda_flag=0, mlp_ratio=4
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=512, group_size=(7, 7), num_heads=16
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=32, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=32, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=32, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=16, bias=True)
            )
          )
          (qkv): Linear(in_features=512, out_features=1536, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (5): CrossFormerBlock(
        dim=512, input_resolution=(7, 7), num_heads=16, group_size=7, lsda_flag=0, mlp_ratio=4
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          dim=512, group_size=(7, 7), num_heads=16
          (pos): DynamicPosBias(
            (pos_proj): Linear(in_features=2, out_features=32, bias=True)
            (pos1): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=32, bias=True)
            )
            (pos2): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=32, bias=True)
            )
            (pos3): Sequential(
              (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
              (1): ReLU(inplace=True)
              (2): Linear(in_features=32, out_features=16, bias=True)
            )
          )
          (qkv): Linear(in_features=512, out_features=1536, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=512, out_features=512, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
        (drop_path): DropPath()
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
  )
)
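The layer shapes in this printout follow from simple convolution arithmetic. The sketch below is a pure-Python check of that arithmetic. It assumes a 224×224 input and the settings implied by the printout and the shape comments in the code: embed_dim=64, first-layer kernels [4, 8, 16, 32], and merging kernels [2, 4]. `split_channels`, `conv_out` and `cel_plan` are hypothetical helpers and are not part of the repository. They mirror two rules from `PatchEmbed` and `PatchMerging`:

- the output channels are split across the kernels;
- the padding is chosen as (kernel - stride) // 2, so every kernel lands on the same grid and the results can be concatenated.

```python
# Pure-Python check of the cross-scale embedding arithmetic (hypothetical helpers).

def split_channels(total, n):
    """Channel share per kernel, following PatchEmbed/PatchMerging:
    kernel i gets total // 2**(i+1), the last kernel gets total // 2**(n-1)."""
    return [total // 2 ** (i + 1) if i < n - 1 else total // 2 ** i for i in range(n)]


def conv_out(size, kernel, stride, padding):
    """Output length of a Conv2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def cel_plan(in_size, out_channels, kernel_sizes, stride):
    """Return [(kernel, padding, channels, out_size), ...], the shared grid size,
    and the concatenated channel count."""
    dims = split_channels(out_channels, len(kernel_sizes))
    plan = []
    for k, d in zip(kernel_sizes, dims):
        p = (k - stride) // 2  # same padding rule as the listing below
        plan.append((k, p, d, conv_out(in_size, k, stride, p)))
    grid = {row[3] for row in plan}
    assert len(grid) == 1, "every kernel must produce the same grid to be concatenated"
    return plan, grid.pop(), sum(dims)


# Stage-1 CEL on a 224x224 image: four kernels, stride 4.
plan1, size1, ch1 = cel_plan(224, 64, [4, 8, 16, 32], stride=4)
print(plan1, size1, ch1)
# Last PatchMerging in the printout: 14x14x256 -> 7x7x512 with kernels 2 and 4, stride 2.
plan4, size4, ch4 = cel_plan(14, 512, [2, 4], stride=2)
print(plan4, size4, ch4)
```

For stage 1 this gives channel shares of 32/16/8/8 on a 56×56 grid, which matches the shape comments in `PatchEmbed.forward`. For the last merging layer it gives two 256-channel 7×7 maps (padding 0 and 1), which matches the two `Conv2d(256, 256, ...)` entries above.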

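Each `DynamicPosBias` above has a 2-dimensional input because it is evaluated on the (dy, dx) offsets of the "mother set" built in `Attention.__init__`. It emits one value per head, which is 16 in the last stage. The pure-Python sketch below rebuilds that offset table and the `relative_position_index` lookup for a G×G group. It uses hypothetical helper names and no torch. It shows that all 49×49 token pairs of a 7×7 group are served by only (2G-1)² = 169 DPB evaluations.

```python
# Pure-Python sketch of the offset table and index used by Attention (hypothetical helpers).

def mother_set(gh, gw):
    """All relative offsets (dy, dx) inside a gh x gw group, in the row-major order
    produced by meshgrid(arange(1-gh, gh), arange(1-gw, gw)) flattened and transposed."""
    return [(dy, dx) for dy in range(1 - gh, gh) for dx in range(1 - gw, gw)]


def relative_position_index(gh, gw):
    """index[a][b] points into mother_set at the offset of token a relative to token b."""
    coords = [(y, x) for y in range(gh) for x in range(gw)]
    index = []
    for ya, xa in coords:
        row = []
        for yb, xb in coords:
            dy = ya - yb + gh - 1  # shift to start from 0
            dx = xa - xb + gw - 1
            row.append(dy * (2 * gw - 1) + dx)
        index.append(row)
    return index


G = 7
table = mother_set(G, G)
idx = relative_position_index(G, G)
print(len(table), len(idx), len(idx[0]))  # 169 49 49
# Every lookup returns exactly the offset between the two tokens.
coords = [(y, x) for y in range(G) for x in range(G)]
assert all(table[idx[a][b]] == (coords[a][0] - coords[b][0], coords[a][1] - coords[b][1])
           for a in range(G * G) for b in range(G * G))
```

In the real module the same gather happens on tensors. `pos[self.relative_position_index.view(-1)]` turns the (169, num_heads) DPB output into a (49, 49, num_heads) bias, which is then added to the attention logits.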
3、Full code

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class DynamicPosBias(nn.Module):
    def __init__(self, dim, num_heads, residual):
        super().__init__()
        self.residual = residual
        self.num_heads = num_heads
        self.pos_dim = dim // 4
        self.pos_proj = nn.Linear(2, self.pos_dim)
        self.pos1 = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.pos_dim, self.pos_dim),
        )
        self.pos2 = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.pos_dim, self.pos_dim)
        )
        self.pos3 = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.pos_dim, self.num_heads)
        )
    def forward(self, biases):
        if self.residual:
            pos = self.pos_proj(biases) # (2Wh-1)*(2Ww-1), pos_dim
            pos = pos + self.pos1(pos)
            pos = pos + self.pos2(pos)
            pos = self.pos3(pos)
        else:
            pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
        return pos

    def flops(self, N):
        flops = N * 2 * self.pos_dim
        flops += N * self.pos_dim * self.pos_dim
        flops += N * self.pos_dim * self.pos_dim
        flops += N * self.pos_dim * self.num_heads
        return flops

class Attention(nn.Module):
    r""" Multi-head self attention module with dynamic position bias.

    Args:
        dim (int): Number of input channels.
        group_size (tuple[int]): The height and width of the group.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, group_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 position_bias=True):

        super().__init__()
        self.dim = dim
        self.group_size = group_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.position_bias = position_bias

        if position_bias:
            self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
            
            # generate mother-set
            position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])
            position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])
            biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))  # 2, 2Wh-1, 2Ww-1
            biases = biases.flatten(1).transpose(0, 1).float()
            self.register_buffer("biases", biases)

            # get pair-wise relative position index for each token inside the group
            coords_h = torch.arange(self.group_size[0])
            coords_w = torch.arange(self.group_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.group_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.group_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_groups*B, N, C)
            mask: (0/-inf) mask with shape of (num_groups, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape # [64, 49, 64] for SDA
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple) #  q.shape=k.shape=v.shape=[64, 2, 49, 32]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.position_bias:
            pos = self.pos(self.biases) # 2Wh-1 * 2Ww-1, heads # [169, 2]
            # select position bias
            relative_position_bias = pos[self.relative_position_index.view(-1)].view(
                self.group_size[0] * self.group_size[1], self.group_size[0] * self.group_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww # [2, 49, 49]
            attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn) # [64, 2, 49, 49]

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # [64, 49, 64]
        x = self.proj(x)            # [64, 49, 64]
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, group_size={self.group_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 group with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        if self.position_bias:
            flops += self.pos.flops(N)
        return flops


class CrossFormerBlock(nn.Module):
    r""" CrossFormer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        group_size (int): Group size.
        lsda_flag (int): use SDA or LDA, 0 for SDA and 1 for LDA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, group_size=7, lsda_flag=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_patch_size=1):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.group_size = group_size
        self.lsda_flag = lsda_flag
        self.mlp_ratio = mlp_ratio
        self.num_patch_size = num_patch_size
        if min(self.input_resolution) <= self.group_size:
            # if group size is larger than input resolution, we don't partition groups
            self.lsda_flag = 0
            self.group_size = min(self.input_resolution)

        self.norm1 = norm_layer(dim)

        self.attn = Attention(
            dim, group_size=to_2tuple(self.group_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
            position_bias=True)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution # [56, 56]
        B, L, C = x.shape            # [1, 3136, 64]
        assert L == H * W, "input feature has wrong size %d, %d, %d" % (L, H, W)

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)       # [1, 56, 56, 64]

        # group embeddings
        G = self.group_size          # 7
        if self.lsda_flag == 0: # 0 for SDA
            x = x.reshape(B, H // G, G, W // G, G, C).permute(0, 1, 3, 2, 4, 5)  # [1, 8, 7, 8, 7, 64] -> [1, 8, 8, 7, 7, 64]
        else: # 1 for LDA
            x = x.reshape(B, G, H // G, G, W // G, C).permute(0, 2, 4, 1, 3, 5)
        x = x.reshape(B * H * W // G**2, G**2, C) # [64, 49, 64] for SDA

        # multi-head self-attention
        x = self.attn(x, mask=self.attn_mask)     # nW*B, G*G, C # [64, 49, 64]

        # ungroup embeddings
        x = x.reshape(B, H // G, W // G, G, G, C) # [1, 8, 8, 7, 7, 64]
        if self.lsda_flag == 0:
            x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)
        else:
            x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, H, W, C)
        x = x.view(B, H * W, C) # [1, 3136, 64]

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))    # [1, 3136, 64]

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"group_size={self.group_size}, lsda_flag={self.lsda_flag}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # LSDA
        nW = H * W / self.group_size / self.group_size
        flops += nW * self.attn.flops(self.group_size * self.group_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
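        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        patch_size (list[int]): Kernel sizes of the stride-2 convolutions whose outputs are concatenated. Default: [2]
        num_input_patch_size (int): Not used by this layer. Default: 1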
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm, patch_size=[2], num_input_patch_size=1):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reductions = nn.ModuleList()
        self.patch_size = patch_size
        self.norm = norm_layer(dim)

        for i, ps in enumerate(patch_size):
            if i == len(patch_size) - 1:
                out_dim = 2 * dim // 2 ** i
            else:
                out_dim = 2 * dim // 2 ** (i + 1)
            stride = 2
            padding = (ps - stride) // 2
            self.reductions.append(nn.Conv2d(dim, out_dim, kernel_size=ps, 
                                                stride=stride, padding=padding))

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = self.norm(x)
        x = x.view(B, H, W, C).permute(0, 3, 1, 2)

        xs = []
        for i in range(len(self.reductions)):
            tmp_x = self.reductions[i](x).flatten(2).transpose(1, 2)
            xs.append(tmp_x)
        x = torch.cat(xs, dim=2)
        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        for i, ps in enumerate(self.patch_size):
            if i == len(self.patch_size) - 1:
                out_dim = 2 * self.dim // 2 ** i
            else:
                out_dim = 2 * self.dim // 2 ** (i + 1)
            flops += (H // 2) * (W // 2) * ps * ps * out_dim * self.dim
        return flops


class Stage(nn.Module):
    """ CrossFormer blocks for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        group_size (int): variable G in the paper, one group has GxG embeddings
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, group_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 patch_size_end=[4], num_patch_size=None):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList()
        for i in range(depth):
            lsda_flag = 0 if (i % 2 == 0) else 1
            self.blocks.append(CrossFormerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, group_size=group_size,
                                 lsda_flag=lsda_flag,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 num_patch_size=num_patch_size))

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer, 
                                         patch_size=patch_size_end, num_input_patch_size=num_patch_size)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x) # [1, 784, 128]
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (list[int]): Kernel sizes of the cross-scale embedding; the first one is also the stride. Default: [4].
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        # patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[0]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.projs = nn.ModuleList()
        for i, ps in enumerate(patch_size):
            if i == len(patch_size) - 1:
                dim = embed_dim // 2 ** i
            else:
                dim = embed_dim // 2 ** (i + 1)
            stride = patch_size[0]
            padding = (ps - patch_size[0]) // 2
            self.projs.append(nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding))
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        xs = []
        for i in range(len(self.projs)):
            tx = self.projs[i](x).flatten(2).transpose(1, 2) #[1,32,56,56],[1,16,56,56],[1,8,56,56],[1,8,56,56]
            xs.append(tx)  # B Ph*Pw C #xs[0]=[1, 3136, 32], xs[1]=[1, 3136, 16], xs[2]=[1, 3136, 8], xs[3]=[1, 3136, 8]
        x = torch.cat(xs, dim=2) # [1, 3136, 64]
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = 0
        for i, ps in enumerate(self.patch_size):
            if i == len(self.patch_size) - 1:
                dim = self.embed_dim // 2 ** i
            else:
                dim = self.embed_dim // 2 ** (i + 1)
            flops += Ho * Wo * dim * self.in_chans * (self.patch_size[i] * self.patch_size[i])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops


class CrossFormer(nn.Module):
    r""" CrossFormer
        A PyTorch implementation of `CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention`

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (list[int]): Kernel sizes of the first cross-scale embedding layer; the first one is also the stride. Default: [4]
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each stage.
        num_heads (tuple(int)): Number of attention heads in different layers.
        group_size (list[int]): Group size G of each stage. Default: [7, 7, 7, 7]
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        merge_size (list[list[int]]): Kernel sizes of the cross-scale merging layer after stages 1-3. Default: [[2], [2], [2]]
    """

    def __init__(self, img_size=224, patch_size=[4], in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 group_size=[7, 7, 7, 7], mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, merge_size=[[2], [2], [2]], **kwargs):
        super().__init__()

        self.num_classes = num_classes       # 1000
        self.num_layers = len(depths)        # 4
        self.embed_dim = embed_dim           # 64
        self.ape = ape                       # False
        self.patch_norm = patch_norm         # True
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) # 512
        self.mlp_ratio = mlp_ratio           # 4

        # split the image into multi-scale patches (kernels larger than the stride overlap)
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches                # 3136
        patches_resolution = self.patch_embed.patches_resolution  # [56, 56]
        self.patches_resolution = patches_resolution              # [56, 56]

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()

        num_patch_sizes = [len(patch_size)] + [len(m) for m in merge_size]  # [4, 2, 2, 2]
        for i_layer in range(self.num_layers):
            patch_size_end = merge_size[i_layer] if i_layer < self.num_layers - 1 else None
            num_patch_size = num_patch_sizes[i_layer]
            layer = Stage(dim=int(embed_dim * 2 ** i_layer),
                          input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                            patches_resolution[1] // (2 ** i_layer)),
                          depth=depths[i_layer],
                          num_heads=num_heads[i_layer],
                          group_size=group_size[i_layer],
                          mlp_ratio=self.mlp_ratio,
                          qkv_bias=qkv_bias, qk_scale=qk_scale,
                          drop=drop_rate, attn_drop=attn_drop_rate,
                          drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                          norm_layer=norm_layer,
                          downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                          use_checkpoint=use_checkpoint,
                          patch_size_end=patch_size_end,
                          num_patch_size=num_patch_size)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        # input_x.shape=[1,3,224,224]
        x = self.patch_embed(x) # x.shape=[1, 3136, 64]
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x) # dropout

        for layer in self.layers:
            x = layer(x)   # [1, 784, 128], [1, 196, 256], [1, 49, 512], [1, 49, 512]

        x = self.norm(x)  # B L C [1, 49, 512]
        x = self.avgpool(x.transpose(1, 2))  # B C 1 # [1, 512, 1]
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        # final LayerNorm over the last stage's H*W / 4**(num_layers-1) tokens (49 for a 224 input)
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (4 ** (self.num_layers - 1))
        flops += self.num_features * self.num_classes
        return flops
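
The shape comments in the listing above (`# 3136`, `# [56, 56]`, `[1, 784, 128]`, ...) can be reproduced with plain arithmetic. The sketch below is an illustration under stated assumptions, not the repository's code: it assumes a 224×224 input, stage-1 CEL kernels 4/8/16/32 with stride equal to the smallest kernel, a padding of `(k - stride) // 2` per kernel, and the channel split D/2, D/4, D/8, D/8 (larger kernels get fewer channels, as the paper describes). The helper names `conv_out`, `cel_stage1` and `stage_shapes` are hypothetical. It shows why the four kernel outputs can be concatenated: every kernel lands on the same 56×56 grid.

```python
# Shape bookkeeping for CrossFormer (illustrative sketch, not the official code).
# Assumptions: stride = smallest kernel, padding = (k - stride) // 2,
# channels split as D/2, D/4, ... with the last kernel reusing the previous share.

def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length along one axis for a strided convolution (standard formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def cel_stage1(img_size=224, kernels=(4, 8, 16, 32), embed_dim=64):
    """Return (grid size, per-kernel channels) for a hypothetical stage-1 CEL."""
    stride = kernels[0]
    grids, dims = [], []
    for i, k in enumerate(kernels):
        grids.append(conv_out(img_size, k, stride, (k - stride) // 2))
        # larger kernels get fewer channels; the last one matches its predecessor
        shift = i if i == len(kernels) - 1 else i + 1
        dims.append(embed_dim // 2 ** shift)
    assert len(set(grids)) == 1, "all kernels must give the same grid to concat"
    return grids[0], dims


def stage_shapes(grid=56, embed_dim=64, num_layers=4):
    """(tokens, channels) entering each stage: each merge halves H and W, doubles C."""
    return [((grid // 2 ** i) ** 2, embed_dim * 2 ** i) for i in range(num_layers)]


grid, dims = cel_stage1()
print(grid, dims, sum(dims))   # 56 [32, 16, 8, 8] 64
print(stage_shapes())          # [(3136, 64), (784, 128), (196, 256), (49, 512)]
```

In `forward_features`, each stage's `PatchMerging` runs at the end of that stage, so the per-layer outputs `[1, 784, 128]`, `[1, 196, 256]`, `[1, 49, 512]` are already the next stage's input shapes, and the last stage keeps `[1, 49, 512]`.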

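The line `dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]` gives every block a drop-path (stochastic depth) probability that grows linearly with depth, and `dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])]` hands each stage its contiguous slice. Below is a minimal pure-Python sketch of that rule, assuming the default `depths=[2, 2, 6, 2]` and `drop_path_rate=0.1`; `linspace` and `per_stage_drop_path` are hypothetical helpers written here to avoid a torch dependency, with `linspace` mimicking `torch.linspace` (both endpoints included).

```python
def linspace(start: float, stop: float, num: int) -> list:
    """Evenly spaced values including both endpoints, like torch.linspace."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def per_stage_drop_path(depths, drop_path_rate):
    """Split the linearly increasing drop-path rates into one slice per stage."""
    dpr = linspace(0.0, drop_path_rate, sum(depths))
    return [dpr[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))]


for i, rates in enumerate(per_stage_drop_path([2, 2, 6, 2], 0.1)):
    print(i, [round(r, 3) for r in rates])
```

With this rule the first block of stage 1 never drops its residual branch (rate 0.0), and the last block of the final stage uses exactly `drop_path_rate`, so deeper blocks are regularized more strongly.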