"Shunted Transformer" -- Code Notes

Paper: https://arxiv.org/pdf/2111.15193.pdf

Code: https://github.com/OliverRensu/Shunted-Transformer

        The model variants are defined in SSA.py and registered with the @register_model decorator:

        The overall flow is as follows:

        step1: model = ShuntedTransformer()

@register_model
def shunted_t(pretrained=False, **kwargs):
    model = ShuntedTransformer(
        patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[1, 2, 4, 1], sr_ratios=[8, 4, 2, 1], num_conv=0,
        **kwargs)
    model.default_cfg = _cfg()

    return model
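
        With the model registered, a quick sanity check is to build it and run a dummy forward pass. A minimal usage sketch, assuming the repo's SSA.py is importable as a module named SSA and the default 1000-class head is wanted:

import torch
from SSA import shunted_t   # assumes SSA.py from the repo is on the Python path

model = shunted_t(num_classes=1000).eval()
x = torch.randn(1, 3, 224, 224)      # dummy ImageNet-sized input
with torch.no_grad():
    logits = model(x)                # expected shape: (1, 1000)
print(logits.shape)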

        step2: class ShuntedTransformer()

class ShuntedTransformer(nn.Module):
    """省略了一些简单的定义"""
     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, num_conv=0):
    """..."""

    def forward_features(self, x):
        B = x.shape[0]

        for i in range(self.num_stages):    # the paper uses 4 stages in total
            patch_embed = getattr(self, f"patch_embed{i + 1}")    # each patch_embed downsamples via Conv2d
            block = getattr(self, f"block{i + 1}")    # block = nn.ModuleList([Block() for j in range(depths[i])])
            norm = getattr(self, f"norm{i + 1}")    # nn.LayerNorm
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            if i != self.num_stages - 1:
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        return x.mean(dim=1)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)    # self.head = nn.Linear(embed_dims[-1], num_classes)

        return x
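
        Before moving on, it helps to see how the four stages trade spatial resolution for channels. The sketch below hooks each per-stage norm layer and prints the token shapes; it assumes SSA.py is importable and that shunted_t follows the usual PVT-style 4x / 2x / 2x / 2x downsampling (so a 224x224 input gives 56/28/14/7 feature maps):

import torch
from SSA import shunted_t   # assumed import, as above

model = shunted_t().eval()
for i in range(model.num_stages):
    norm = getattr(model, f"norm{i + 1}")
    # print the (B, N, C) token shape leaving each stage
    norm.register_forward_hook(lambda m, inp, out, i=i: print(f"stage {i + 1}: {tuple(out.shape)}"))

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))
# Expected under the assumptions above:
# stage 1: (1, 3136, 64)   56x56 tokens
# stage 2: (1, 784, 128)   28x28
# stage 3: (1, 196, 256)   14x14
# stage 4: (1, 49, 512)    7x7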

        step3: class Block()

class Block(nn.Module):

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))    # self.attn = Attention()
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x
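
        Note that the MLP is also called with H and W: the feed-forward in this repo is the paper's detail-specific feed-forward, which folds the token sequence back into a feature map and runs a depth-wise convolution on it. A simplified sketch of the idea (class and member names here are illustrative, not the repo's exact code):

import torch
import torch.nn as nn

class DetailSpecificFFN(nn.Module):
    """Illustrative detail-specific feed-forward: Linear -> depth-wise conv detail branch -> Linear."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.dwconv = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1, groups=hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x, H, W):
        B, N, C = x.shape                                     # N == H * W
        x = self.fc1(x)                                       # (B, N, hidden)
        feat = x.transpose(1, 2).view(B, -1, H, W)            # tokens -> (B, hidden, H, W)
        x = x + self.dwconv(feat).flatten(2).transpose(1, 2)  # add the depth-wise "detail" branch
        x = self.act(x)
        return self.fc2(x)                                    # back to (B, N, dim)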

        step4: class Attention() (the important part!)

        The formulas from the paper:

\begin{aligned}
Q_{i} &= X W_{i}^{Q} \\
K_{i}, V_{i} &= \operatorname{MTA}\left(X, r_{i}\right) W_{i}^{K},\ \operatorname{MTA}\left(X, r_{i}\right) W_{i}^{V} \\
V_{i} &= V_{i} + \operatorname{LE}\left(V_{i}\right)
\end{aligned}

        where MTA(X, r_i) is the multi-scale token aggregation (tokens are merged at rate r_i), and LE(·) is the local enhancement term, a grouped (depth-wise) convolution applied to V;
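
        For a sense of scale: in stage 1 of shunted_t (assuming the usual 56x56 token map after the 4x stem, with sr_ratio = 8), half of the heads attend over 56*56/8^2 = 49 aggregated K/V tokens and the other half over 56*56/4^2 = 196, while all N = 3136 queries are kept. A minimal sketch of the aggregation step itself (illustrative; in the code below, the sr1/sr2 convolutions play this role):

import torch
import torch.nn as nn

B, C, H, W, r = 2, 64, 56, 56, 8
tokens = torch.randn(B, H * W, C)                  # token sequence (B, N, C)
mta = nn.Conv2d(C, C, kernel_size=r, stride=r)     # merge r x r token patches
feat = tokens.transpose(1, 2).reshape(B, C, H, W)  # tokens -> feature map
agg = mta(feat).reshape(B, C, -1).transpose(1, 2)  # (B, HW / r^2, C) = (2, 49, 64)
print(agg.shape)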


class Attention(nn.Module):
    def forward(self, x, H, W):
        B, N, C = x.shape   # x = (b,n,c)
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        # self.q() = nn.Linear(dim,dim)
        # q = (b, heads, n, c/heads)
        if self.sr_ratio > 1:
                x_ = x.permute(0, 2, 1).reshape(B, C, H, W) # x_ = (b,c,h,w)
                # Why two branches x_1 and x_2? Half of the heads get K/V aggregated at a
                # coarse rate r and the other half at a finer rate r/2; this is the
                # multi-scale ("shunted") part of the attention.
                # sr1: Conv2d (when r=8: kernel=8, stride=8)
                # sr2: Conv2d (when r=8: kernel=4, stride=4)
                x_1 = self.act(self.norm1(self.sr1(x_).reshape(B, C, -1).permute(0, 2, 1)))
                x_2 = self.act(self.norm2(self.sr2(x_).reshape(B, C, -1).permute(0, 2, 1)))
                # x_1 = (b, hw/(8*8), c);
                # x_2 = (b, hw/(4*4), c);

                kv1 = self.kv1(x_1).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4)
                kv2 = self.kv2(x_2).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4)
                #   self.kv1 = nn.Linear(dim, dim)
                #   kv1 = (2, b, heads/2, hw/(8*8), c/heads); dims 1, 3 and 5 multiply back to c,
                #   so this is really (b, hw/(r*r), c) with the channel dim split into k/v, heads and head-dim;
                #   kv2 = (2, b, heads/2, hw/(4*4), c/heads)
                k1, v1 = kv1[0], kv1[1] # (b, heads/2, hw/(8*8), c/heads)
                k2, v2 = kv2[0], kv2[1] # (b, heads/2, hw/(4*4), c/heads)
                # @ is matrix multiplication;
                attn1 = (q[:, :self.num_heads//2] @ k1.transpose(-2, -1)) * self.scale
                # attn1: q (b, heads/2, n, c/heads) @ k1^T (b, heads/2, c/heads, hw/(r*r)) -> (b, heads/2, n, hw/(r*r))
                attn1 = attn1.softmax(dim=-1)
                attn1 = self.attn_drop(attn1)

                v1 = v1 + self.local_conv1(v1.transpose(1, 2).reshape(B, -1, C//2).
                                        transpose(1, 2).view(B,C//2, H//self.sr_ratio, W//self.sr_ratio)).\
                    view(B, C//2, -1).view(B, self.num_heads//2, C // self.num_heads, -1).transpose(-1, -2)
                    # v1 = (b, heads/2, hw/(8*8), c/heads)
                    # -> local_conv1 (depth-wise conv, one group per channel, spatial size unchanged) on (b, c/2, h/8, w/8)
                    # -> reshaped back to (b, heads/2, hw/(8*8), c/heads)
                x1 = (attn1 @ v1).transpose(1, 2).reshape(B, N, C//2)   # x1 = (b, n, c/2); the matmul sums out the hw/(r*r) dimension;


                attn2 = (q[:, self.num_heads // 2:] @ k2.transpose(-2, -1)) * self.scale
                attn2 = attn2.softmax(dim=-1)
                attn2 = self.attn_drop(attn2)
                v2 = v2 + self.local_conv2(v2.transpose(1, 2).reshape(B, -1, C//2).
                                        transpose(1, 2).view(B, C//2, H*2//self.sr_ratio, W*2//self.sr_ratio)).\
                    view(B, C//2, -1).view(B, self.num_heads//2, C // self.num_heads, -1).transpose(-1, -2)
                x2 = (attn2 @ v2).transpose(1, 2).reshape(B, N, C//2)

                x = torch.cat([x1,x2], dim=-1)
        else:
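                # sr_ratio == 1 (the last stage): a sketch of the standard single-scale branch,
                # written from the released repo -- the member names (self.kv, self.local_conv,
                # self.proj) are assumptions and should be checked against SSA.py.
                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
                k, v = kv[0], kv[1]                              # (b, heads, n, c/heads)
                attn = (q @ k.transpose(-2, -1)) * self.scale    # (b, heads, n, n)
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                # same V = V + LE(V) trick as above, but at full resolution and for all heads
                v = v + self.local_conv(v.transpose(1, 2).reshape(B, N, C).
                                        transpose(1, 2).view(B, C, H, W)).\
                    view(B, C, N).view(B, self.num_heads, C // self.num_heads, N).transpose(-1, -2)
                x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # final projection (also from the repo; names assumed)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x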
