SwinTransformer 代码复现 - 模型搭建

参考Swin Transformer官方源码, 我们做了如下修改:

  1. 添加了DropPath(随机深度)策略
  2. 每一个stage的输出添加了norm层
  3. 每一个PatchMerging层添加了norm层
  4. 源码里每一个block的每一个head都使用各自独立(不共享)的相对位置偏置(pos bias); 我们这里所有block和所有head共享同一个pos bias
# --*-- coding:utf-8 --*--
import torch 
import torch.nn as nn 
from torch.nn import functional as F 

import torch.utils.checkpoint as cp
from mmcv.cnn import (constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from timm.models.layers import DropPath, trunc_normal_
from mmdet.utils import get_root_logger
from ..builder import BACKBONES

class Mlp(nn.Module):
    """Point-wise feed-forward network made of two 1x1 convolutions.

    Works directly on BCHW feature maps:
    conv1x1 -> activation -> dropout -> conv1x1 -> dropout.

    Args:
        in_feature (int): number of input channels.
        hidden_feature (int, optional): hidden channels; a falsy value means
            ``in_feature * Mlp.expasion``.
        out_feature (int, optional): output channels; a falsy value means
            ``in_feature``.
        act_layer (type): activation module class, instantiated without args.
        drop (float): dropout probability applied after the activation and
            after the second convolution.
    """
    # Hidden-width multiplier (spelling kept as-is: it is a public attribute).
    expasion = 4

    def __init__(self, in_feature, hidden_feature=None, out_feature=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_channels = out_feature if out_feature else in_feature
        hidden_channels = hidden_feature if hidden_feature else in_feature * self.expasion
        # Module creation order is kept so parameter init and state_dict order are unchanged.
        self.fc1 = nn.Conv2d(in_feature, hidden_channels, 1, 1, 0)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        out = self.fc2(hidden)
        return self.drop(out)

class PatchEmbedding(nn.Module):
    """Split a BCHW image into non-overlapping patches and embed each one.

    A strided convolution (kernel == stride == ``kernel_size``) maps every
    patch to ``out_feature`` channels. Inputs whose width/height are not
    multiples of the patch size are zero-padded on the right/bottom first.
    An optional channels-last norm layer is applied to the result.

    Args:
        in_feature (int): input channels.
        out_feature (int): embedding channels.
        kernel_size (int): patch size.
        norm_layer (type | None): norm class applied over channels, or None.
        drop (float): dropout probability after the projection.
    """

    def __init__(self, in_feature, out_feature, kernel_size=4, norm_layer=nn.LayerNorm, drop=0.):
        super().__init__()
        self.patch_size = kernel_size
        self.fc = nn.Conv2d(in_feature, out_feature, kernel_size=kernel_size, stride=kernel_size, padding=0)
        self.drop = nn.Dropout(drop)
        self.norm = None if norm_layer is None else norm_layer(out_feature)

    def forward(self, x):
        _, _, height, width = x.size()
        patch = self.patch_size
        # F.pad pads the last dim first: (left, right) then (top, bottom).
        if width % patch:
            x = F.pad(x, (0, patch - width % patch))
        if height % patch:
            x = F.pad(x, (0, 0, 0, patch - height % patch))
        x = self.drop(self.fc(x))
        if self.norm is None:
            return x
        channels_last = x.permute(0, 2, 3, 1).contiguous()
        return self.norm(channels_last).permute(0, 3, 1, 2).contiguous()

class PatchMerging(nn.Module):
    """Downsample a BCHW map by merging every ``kernel_size x kernel_size`` patch.

    Each non-overlapping k x k neighbourhood is flattened into a vector of
    ``in_feature * k**2`` values (channel-major), normalised, and linearly
    projected to ``out_feature`` channels. Inputs whose height/width are not
    multiples of ``kernel_size`` are zero-padded on the bottom/right first.
    This matches the reference Swin implementation and avoids a shape error
    on odd feature-map sizes.

    Args:
        in_feature (int): input channels.
        out_feature (int): output channels.
        kernel_size (int): merge (downsampling) factor.
        norm_layer (type | None): norm class applied to the merged vector;
            None disables normalisation.
        drop (float): dropout probability after the projection.
    """

    def __init__(self, in_feature, out_feature, kernel_size=2, norm_layer=nn.LayerNorm, drop=0.):
        super().__init__()
        merged_dim = in_feature * kernel_size ** 2
        self.fc = nn.Linear(merged_dim, out_feature, bias=False)
        self.kernel_size = kernel_size
        # nn.Identity has no parameters, so the state_dict is unchanged when a norm is used.
        self.norm = norm_layer(merged_dim) if norm_layer is not None else nn.Identity()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        k = self.kernel_size
        B, C, H, W = x.size()
        pad_h, pad_w = (-H) % k, (-W) % k
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h))
            H, W = H + pad_h, W + pad_w
        # B x C x H x W -> B x H/k x W/k x C x k x k
        # reshape (not view) because callers may pass cropped, non-contiguous maps.
        x = x.reshape(B, C, H // k, k, W // k, k).permute(0, 2, 4, 1, 3, 5).contiguous()
        x = self.drop(self.fc(self.norm(torch.flatten(x, 3))))
        return x.permute(0, 3, 1, 2).contiguous()

class WMSA(nn.Module):
    """Window-based multi-head self-attention on BCHW feature maps.

    The map is split into non-overlapping M x M windows and attention is
    computed independently inside each window. The q / kv / output
    projections are 1x1 convolutions.

    Args:
        dim (int): input/output channels; must equal ``n_heads * head_dim``
            for the reshapes in ``forward`` to succeed.
        head_dim (int): channels per head; ``n_heads = dim // head_dim``.
        M (int): window size.
        qkv_bias (bool): whether the q and kv convolutions have a bias.
        qk_scale (float, optional): attention scale; defaults to
            ``head_dim ** -0.5``.
        attn_drop (float): dropout on the attention weights.
        proj_drop (float): dropout after the output projection.
    """
    def __init__(self, dim, head_dim=32, M=7, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__() 

        self.dim = dim 
        self.n_heads = dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5 
        self.head_dim = head_dim
        self.M = M 
        self.q = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, bias=qkv_bias)
        # One conv produces both k and v: the first `dim` output channels are k, the rest v.
        self.kv = nn.Conv2d(dim, dim*2, kernel_size=1, stride=1, padding=0, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop )
        self.proj = nn.Conv2d(dim, dim, 1, 1, 0)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, pos_bias, shift, masks):
        """Apply window attention.

        :param x:   tensor, BCHW; H and W must be multiples of M (the views below require it)
        :param pos_bias:    tensor, M^2 x M^2, broadcast to every window and every head
        :param shift:       int, 0 or 1; 1 means the input was cyclically shifted and the
                            boundary windows must be masked
        :param masks:       dict{"top":tensor(M^2 x M^2), "left":tensor, "topleft":tensor},
                            additive masks (0 / -inf) used only when shift == 1
        :return:    tensor, BCHW
        """
        B, C, H, W = x.size()
        wn_h, wn_w = H//self.M, W//self.M  # number of windows along H and W

        # B x C x H x W -> B x wh x ww x n_head x head_dim x M x M
        q = self.q(x).view(B, self.n_heads, self.head_dim, wn_h, self.M, wn_w, self.M).permute(0, 3, 5, 1,2,4,6).contiguous()
        q = q.view(B, wn_h, wn_w, self.n_heads, self.head_dim, -1)  # B x wh x ww x n_head x head_dim x M^2
        kv = self.kv(x).view(B, 2, self.n_heads, self.head_dim, wn_h, self.M, wn_w, self.M).permute(1, 0, 4, 6, 2, 3, 5, 7).contiguous()  # 2 x B x wh x ww x n_head x head_dim x M x M
        kv = kv.view(2, B, wn_h, wn_w, self.n_heads, self.head_dim, -1)  # 2 x B x wh x ww x n_head x head_dim x M^2
        k, v = kv[0], kv[1]
        # Rows index query positions, columns key positions; the same pos_bias is added for every head.
        attn = ((q.transpose(-2, -1).contiguous())@k) * self.scale + pos_bias.expand(1, 1, 1, 1, self.M**2, self.M**2)  # B x wh x wn x n_head x M^2 x M^2

        if shift==1:
            # After the caller's cyclic shift, windows in the last window-column / last
            # window-row / bottom-right corner mix pixels that were not neighbours.
            # The additive -inf masks stop attention across those regions.
            attn[:, :-1, -1] += masks["left"].expand(1, 1, 1, self.M**2, self.M**2).to(x.device)   # last column, all but last row
            attn[:, -1, :-1] += masks["top"].expand(1, 1, 1, self.M**2, self.M**2).to(x.device)    # last row, all but last column
            attn[:, -1, -1] += masks["topleft"].expand(1, 1, self.M**2, self.M**2).to(x.device)    # bottom-right window
            
        attn = self.attn_drop(F.softmax(attn, dim=-1))
        # v (head_dim x M^2 keys) @ attn^T (M^2 keys x M^2 queries) -> head_dim x M^2 queries
        x = (v @ (attn.transpose(-2, -1).contiguous())).view(B, wn_h, wn_w, -1, self.M, self.M)  # B x wh x wn x C x M x M
        # Stitch the windows back into a B x C x H x W map.
        x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(B, C, H, W)
        x = self.proj_drop(self.proj(x))
        return x

class SwinTransformerBlock(nn.Module):
    """One Swin block: (shifted-)window attention followed by an MLP.

    Both sub-layers are pre-norm residual branches wrapped in DropPath.
    Before attention, the input is zero-padded on the bottom/right to a
    multiple of the window size ``M``. When ``shift == 1`` the padded map is
    cyclically rolled by ``-(M // 2)`` along H and W, and rolled back after
    the MLP. The output is cropped back to the input's spatial size.

    Args:
        in_feature (int): channels of the input/output map.
        head_dim (int): channels per attention head.
        M (int): window size.
        shift (int): 0 for regular windows, 1 for shifted windows.
        norm_layer (type): channels-last norm class used before each sub-layer.
        drop_path (float): DropPath rate; 0 disables it.
    """

    def __init__(self, in_feature, head_dim=32, M=7, shift=0, norm_layer=nn.LayerNorm, drop_path=0.):
        super().__init__()
        self.norm1 = norm_layer(in_feature)
        self.norm2 = norm_layer(in_feature)
        self.multi_attn = WMSA(in_feature, head_dim, M)
        self.mlp = Mlp(in_feature)
        self.M = M
        self.shift = shift
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    @staticmethod
    def _norm_bchw(norm, t):
        """Apply a channels-last norm layer to a BCHW tensor."""
        return norm(t.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous()

    def forward(self, x, pos_bias, masks):
        _, _, in_h, in_w = x.size()
        half_window = self.M // 2

        # Amount of bottom/right zero padding needed to reach a multiple of M.
        pad_bottom = (-in_h) % self.M
        pad_right = (-in_w) % self.M
        x = F.pad(x, (0, pad_right, 0, pad_bottom), 'constant', 0)

        is_shifted = self.shift == 1
        if is_shifted:
            x = torch.roll(x, (-half_window, -half_window), (-1, -2))

        attended = self.multi_attn(self._norm_bchw(self.norm1, x), pos_bias, self.shift, masks)
        y = x + self.drop_path(attended)
        y = y + self.drop_path(self.mlp(self._norm_bchw(self.norm2, y)))

        if is_shifted:
            y = torch.roll(y, (half_window, half_window), (-1, -2))

        return y[..., :in_h, :in_w]

class Stage(nn.Module):
    """One hierarchical stage: a downsampling layer followed by Swin blocks.

    Blocks alternate between regular (``shift=0``) and shifted (``shift=1``)
    window attention, starting with a regular one. An odd ``num_layers``
    therefore ends with a regular block instead of silently dropping it.

    Args:
        in_feature (int): input channels.
        out_feature (int): channels produced by the downsampling layer and
            used by every block.
        num_layers (int): number of Swin blocks in the stage.
        patch_norm (bool): whether the downsampling layer gets a LayerNorm.
        patch_merge (type): downsampling class, constructed as
            ``patch_merge(in_feature, out_feature, kernel_size=stride, norm_layer=...)``.
        M (int): window size.
        head_dim (int): channels per attention head.
        stride (int): downsampling factor.
        drop_path (float | Sequence[float]): DropPath rate per block, or a
            single rate shared by all blocks.

    Raises:
        ValueError: if fewer than ``num_layers`` drop-path rates are given.
    """

    def __init__(self, in_feature, out_feature, num_layers, patch_norm=True, patch_merge=PatchEmbedding, M=7, head_dim=32, stride=2, drop_path=(0.2, 0.2)):
        super().__init__()
        self.downsample = patch_merge(in_feature, out_feature, kernel_size=stride, norm_layer=nn.LayerNorm if patch_norm else None)
        if isinstance(drop_path, (int, float)):
            drop_path = [float(drop_path)] * num_layers
        elif len(drop_path) < num_layers:
            raise ValueError(
                f"expected {num_layers} drop_path rates, got {len(drop_path)}")
        # Block k uses shift k % 2: identical to the original pairwise layout for even depths.
        self.blocks = nn.ModuleList(
            SwinTransformerBlock(out_feature, head_dim, M, k % 2, drop_path=drop_path[k])
            for k in range(num_layers)
        )

    def forward(self, x, pos_bias, masks):
        x = self.downsample(x)
        for m in self.blocks:
            x = m(x, pos_bias, masks)
        return x

class PosBias(nn.Module):
    """Learnable relative position bias for M x M attention windows.

    Stores one scalar per relative offset (dy, dx), with dy and dx in
    [-(M-1), M-1], i.e. ``(2M-1)**2`` entries. ``forward`` expands them into
    an ``M^2 x M^2`` matrix whose entry [a, b] is the bias for the offset
    position(a) - position(b), with positions numbered row-major.

    Args:
        M (int): window size.
    """

    def __init__(self, M):
        super().__init__()
        self.M = M
        self.emb_dict = nn.Embedding((2 * M - 1) ** 2, 1)

    def forward(self, device):
        size = self.M
        span = torch.arange(size)
        # Row / column coordinate of each window position, in row-major order.
        row_of = span.repeat_interleave(size)
        col_of = span.repeat(size)
        # Pairwise offsets shifted into [0, 2M-2], then encoded as row * (2M-1) + col.
        d_row = row_of[:, None] - row_of[None, :] + (size - 1)
        d_col = col_of[:, None] - col_of[None, :] + (size - 1)
        table_index = (d_row * (2 * size - 1) + d_col).long().to(device)
        return self.emb_dict(table_index).squeeze(-1)  # M**2 x M**2

@BACKBONES.register_module()
class SwinTransformer(nn.Module):
    """SwinTransformer backbone 

    Args:
        model_type (str): type of the swin transformer, from {'T', 'S', 'B', 'L'}
            (see ``arch_settings`` for the embedding width and stage depths).
        out_indices (Sequence [int]): Output from which stages. Each returned
            map goes through its own LayerNorm.
        M (int): size of the window.
        head_dim (int): dim of each head in MSA.
        patch_norm (bool): whether the patch-embedding layer uses LayerNorm.
        frozen_stages (int): Number of leading stages to freeze (stop grad and
            keep in eval mode). Any value >= 0 also freezes the shared
            position bias. -1 means not freezing any parameters.
        drop_path_rate (float): maximum stochastic-depth rate; per-block rates
            increase linearly from 0 to this value.

    Example:
        >>> from mmdet.models import SwinTransformer
        >>> import torch
        >>> self = SwinTransformer(model_type="T")
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 96, 8, 8)
        (1, 192, 4, 4)
        (1, 384, 2, 2)
        (1, 768, 1, 1)
    """

    arch_settings = {
        "T": (96, (2, 2, 6, 2)),
        "S": (96, (2, 2, 18, 2)),
        "B": (128, (2, 2, 18, 2)),
        "L": (192, (2, 2, 18, 2))
    }

    def __init__(self, model_type, out_indices=(0,1,2,3), M=7, head_dim=32, patch_norm=True, frozen_stages=-1, drop_path_rate=0.2):
        super().__init__()
        init_feature, layers = self.arch_settings[model_type]
        self.frozen_stages = frozen_stages
        self.out_indices = out_indices

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]  # stochastic depth decay rule

        # A single relative position bias shared by every block and head.
        self.pos_bias = PosBias(M)
        self.backbone = nn.ModuleList()
        self.backbone.append(Stage(3, init_feature, layers[0], patch_norm, PatchEmbedding, M, head_dim, stride=4, drop_path=dpr[:layers[0]]))
        out_feature_dims = [init_feature]
        for i, v in enumerate(layers[1:]):
            self.backbone.append(Stage(init_feature, 2*init_feature, v, True, PatchMerging, M, head_dim, stride=2,
                                       drop_path=dpr[sum(layers[:i+1]):sum(layers[:i+2])]))
            init_feature *= 2
            out_feature_dims.append(init_feature)

        # add a norm layer for each output
        for k in out_indices:
            self.add_module(f'norm_stage{k}', nn.LayerNorm(out_feature_dims[k]))

        self.M = M
        # Plain dict (not buffers) so the state_dict keys stay unchanged; moved to the device in forward().
        self.masks = {
            "top": self.create_mask("top"),
            "left": self.create_mask("left"),
            "topleft": self.create_mask("topleft")
        }

        self._freeze_stages()

    def create_mask(self, d):
        """Build the additive attention mask for one kind of boundary window.

        After the cyclic shift by ``M // 2``, the window rows/columns starting
        at index ``M - M // 2`` hold pixels wrapped around from the top ("top"),
        the left ("left") or both ("topleft") of the map. Every window position
        is labelled with its region. Query/key pairs from different regions
        get ``-inf``, all others 0.

        :param d:   str, one of ("top", "left", "topleft")
        :return:    tensor, M^2 x M^2
        :raises ValueError: for any other direction
        """
        split = self.M - self.M // 2  # first row/column of the wrapped-around region
        region = torch.zeros(self.M, self.M)
        if d == "top":
            region[split:] = 1
        elif d == "left":
            region[:, split:] = 1
        elif d == "topleft":
            region[:split, split:] = 1
            region[split:, :split] = 2
            region[split:, split:] = 3
        else:
            raise ValueError(f"unknown mask direction: {d!r}")
        region = region.flatten()
        crosses_region = region[:, None] != region[None, :]
        return torch.zeros(self.M ** 2, self.M ** 2).masked_fill(crosses_region, float('-inf'))

    def _freeze_stages(self):
        """Stop gradients for the frozen parts and keep them in eval mode."""
        if self.frozen_stages >= 0:
            self.pos_bias.eval()
            for param in self.pos_bias.parameters():
                param.requires_grad = False
        for i in range(self.frozen_stages):
            stage = self.backbone[i]
            # eval() disables Dropout/DropPath inside the frozen stage.
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: if ``pretrained`` is neither a str nor None.
        """
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            elif isinstance(m, nn.Conv2d):
                kaiming_init(m)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Embedding):
                # Relative position bias table: small init as in the reference Swin,
                # instead of nn.Embedding's default N(0, 1).
                trunc_normal_(m.weight, std=.02)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run all stages and return the normalised outputs of ``out_indices`` (BCHW each)."""
        outs = []
        posb = self.pos_bias(x.device)
        # Move the masks once per call instead of once per attention layer.
        masks = {name: mask.to(x.device) for name, mask in self.masks.items()}
        for i, stage in enumerate(self.backbone):
            x = stage(x, posb, masks)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm_stage{i}')
                outs.append(norm_layer(x.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous())
        return tuple(outs)

    def train(self, mode=True):
        """Set train/eval mode, then re-apply freezing so frozen parts stay in eval mode.

        Returns:
            SwinTransformer: ``self``, like ``nn.Module.train``.
        """
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()
        return self
      ```

你可能感兴趣的:(SwinTransformer 代码复现 - 模型搭建)