[Paper Source Code for Beginners] Reading the Swin-Transformer Source Code

This series will cover two directions: CV Transformers and OCR.

Transformers

  • Swin-Transformer walkthrough [link]
  • CvT walkthrough [link]
  • More to come

OCR

  • DBNet walkthrough [link] (in progress...)
  • PP-OCR [link] (coming later)
  • More to come

Overview

Swin Transformer is an excellent recent paper from Microsoft, published at ICCV 2021.
GitHub [source code]
Paper [link]
First, the overall architecture diagram:
[Figure 1: Swin Transformer overall architecture]
(The code below has been simplified for readability.)

SwinTransformer

The main model code.

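import torch
import torch.nn as nn

# Overall structure: PatchEmbed() splits the image into patches, which then pass through num_layers BasicLayer() stages.
class SwinTransformer(nn.Module):
    # defaults below follow the Swin-T configuration
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, drop_rate=0., norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_layers = len(depths)
        # width after the last stage: every PatchMerging doubles the channels
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
                                      embed_dim=embed_dim, norm_layer=norm_layer)
        patches_resolution = (img_size // patch_size, img_size // patch_size)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        # classification head; change num_classes to fit your own task
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        # B C H W -> B (H/4)*(W/4) embed_dim
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        # the lines below pool the tokens for classification; adjust them for other tasks
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x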
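The classification tail of forward_features (self.norm, then avgpool on x.transpose(1, 2), then flatten) amounts to averaging over the token axis. The short check below illustrates that; the sizes (2 images, 7×7 = 49 tokens, 768 channels) are made up to resemble the last stage of a Swin-T-sized model.

```python
import torch
import torch.nn as nn

B, L, C = 2, 49, 768          # illustrative sizes: 7*7 tokens, 768 channels
x = torch.randn(B, L, C)      # stands in for the output of self.norm(x): B L C

pool = nn.AdaptiveAvgPool1d(1)
pooled = pool(x.transpose(1, 2))   # B C L -> B C 1, pooling over the token axis
pooled = torch.flatten(pooled, 1)  # B C

print(pooled.shape)                                      # torch.Size([2, 768])
print(torch.allclose(pooled, x.mean(dim=1), atol=1e-5))  # True
```

So if a downstream task needs per-token features (detection, segmentation), taking x before the avgpool line keeps the spatial layout.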


PatchEmbed

Splits the image into patches.

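class PatchEmbed(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        # kernel_size == stride: each output position sees exactly one non-overlapping patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        # flatten(2) flattens every dimension from dim 2 onward. After proj, x is
        # b embed_dim h/4 w/4 (with patch_size=4), so the result is b (h/4)*(w/4) embed_dim

        if self.norm is not None:
            x = self.norm(x)
        return x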

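The paper describes this step as splitting the image into non-overlapping patches (the kind of patch extraction nn.Unfold does: it slides a window like a convolution but only copies values, without the multiply-and-sum) and then applying a linear embedding to each patch's raw pixel values. The code fuses both steps into self.proj, a Conv2d with kernel_size = stride = patch_size, followed by flatten(2). Because kernel and stride are equal, the windows do not overlap, so this convolution computes exactly the same result as unfold followed by one linear layer shared across all patches; it is simply a more efficient way to write it, not a different or "smoother" operation. With patch_size=4 the spatial size shrinks by 4 in each direction, and the output is b (h/4)*(w/4) embed_dim. The 16c figure (4×4×c raw values per patch, i.e. 48 for an RGB image) is the per-patch length before the projection, not the output width.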
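The equivalence described above can be checked numerically. This sketch (sizes chosen arbitrarily: an 8×8 RGB input, patch_size=4, embed_dim=96) runs the PatchEmbed-style convolution and an explicit F.unfold plus matrix multiply that reuses the same weights, then compares the two results.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 3, 8, 8
patch_size, embed_dim = 4, 96

x = torch.randn(B, C, H, W)
proj = nn.Conv2d(C, embed_dim, kernel_size=patch_size, stride=patch_size)

# Path 1: what PatchEmbed does (strided conv, then flatten + transpose)
out_conv = proj(x).flatten(2).transpose(1, 2)                     # B, (H/4)*(W/4), embed_dim

# Path 2: cut out the patches explicitly, then apply one shared linear map
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # B, C*4*4, L
patches = patches.transpose(1, 2)                                 # B, L, 48 raw values per patch
weight = proj.weight.view(embed_dim, -1)                          # embed_dim, 48
out_linear = patches @ weight.t() + proj.bias                     # B, L, embed_dim

print(out_conv.shape)                                   # torch.Size([2, 4, 96])
print(torch.allclose(out_conv, out_linear, atol=1e-5))  # True
```

The two paths agree because F.unfold lays out each patch as (channel, row, column), the same order in which the conv weight of shape (embed_dim, C, 4, 4) is flattened.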

# flatten example
>>> a=torch.randn(1,2,3,4)
>>> a.flatten(2).shape
torch.Size([1, 2, 12])

BasicLayer

The core stage layer (the paper uses 4 stages).

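class BasicLayer(nn.Module):

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, downsample=None):

        super().__init__()
        # build blocks: even-indexed blocks use W-MSA (shift 0), odd-indexed ones use SW-MSA
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio, norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
            # x: b h*w c -> b (h/2)*(w/2) 2c
        return x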

As the figure shows, a BasicLayer has two parts, PatchMerging and the Swin Transformer blocks, which correspond to downsample and SwinTransformerBlock() in the code.

downsample is passed in as an argument. In SwinTransformer, each BasicLayer is constructed like this (excerpt, other arguments omitted):

BasicLayer(input_resolution=(patches_resolution[0] // (2 ** i_layer),
                             patches_resolution[1] // (2 ** i_layer)),
           window_size=window_size,
           # apply PatchMerging in every stage except the last one
           downsample=PatchMerging if (i_layer < self.num_layers - 1) else None)

[Figure 2: BasicLayer structure, Swin Transformer blocks followed by PatchMerging]
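The figure is drawn this way to match the control flow in the code: inside each BasicLayer the Swin Transformer blocks run first and PatchMerging comes after them, and the last stage has no PatchMerging. (The paper's figure instead places patch merging at the start of stages 2 to 4; the two arrangements are equivalent.)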
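To make the stage bookkeeping concrete, here is a small pure-Python sketch that traces the token count and channel width through the four stages, assuming the Swin-T defaults (224×224 input, patch_size=4, embed_dim=96, 4 stages). The function name stage_shapes is hypothetical; it only redoes the arithmetic of the BasicLayer(...) arguments above.

```python
def stage_shapes(img_size=224, patch_size=4, embed_dim=96, num_layers=4):
    """Return ((tokens_in, dim_in), has_merge, (tokens_out, dim_out)) for each stage."""
    res = img_size // patch_size  # side length after PatchEmbed (56 for 224/4)
    shapes = []
    for i in range(num_layers):
        in_res = res // (2 ** i)          # same formula as input_resolution in BasicLayer(...)
        dim = embed_dim * 2 ** i          # channels double after every merge
        has_merge = i < num_layers - 1    # the last stage has no PatchMerging
        out_res = in_res // 2 if has_merge else in_res
        out_dim = dim * 2 if has_merge else dim
        shapes.append(((in_res * in_res, dim), has_merge, (out_res * out_res, out_dim)))
    return shapes


for i, s in enumerate(stage_shapes()):
    print(i, s)
# 0 ((3136, 96), True, (784, 192))
# 1 ((784, 192), True, (196, 384))
# 2 ((196, 384), True, (49, 768))
# 3 ((49, 768), False, (49, 768))
```

The final width of 768 is the num_features that self.norm and self.head operate on.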

PatchMerging

The downsampling (patch merging) module used at the end of each stage.

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution  # (H, W) of the incoming token grid
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)
        # after reduction x is b (h/2)*(w/2) 2c: channels go 4c -> 2c while height and width are halved
        return x

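This code is a bit unusual. It uses strided slicing plus concatenation to turn b h w c into b (h/2)*(w/2) 4c, gathering each 2×2 neighborhood into one token, and then a linear layer reduces the channels from 4c to 2c, so the final shape is b (h/2)*(w/2) 2c.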
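A tiny example makes the slicing order visible. With a 4×4 single-channel input filled with 0..15, each merged token holds the four pixels of one 2×2 block, in the order top-left, bottom-left, top-right, bottom-right (x0, x1, x2, x3). The nn.Linear at the end stands in for self.reduction and is randomly initialized here, so only its output shape is meaningful.

```python
import torch
import torch.nn as nn

B, H, W, C = 1, 4, 4, 1
x = torch.arange(H * W, dtype=torch.float32).view(B, H, W, C)
# x[0, :, :, 0] is
# [[ 0,  1,  2,  3],
#  [ 4,  5,  6,  7],
#  [ 8,  9, 10, 11],
#  [12, 13, 14, 15]]

x0 = x[:, 0::2, 0::2, :]  # top-left pixel of every 2x2 block
x1 = x[:, 1::2, 0::2, :]  # bottom-left
x2 = x[:, 0::2, 1::2, :]  # top-right
x3 = x[:, 1::2, 1::2, :]  # bottom-right
merged = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)  # B H/2*W/2 4C

print(merged.shape)        # torch.Size([1, 4, 4])
print(merged[0].tolist())
# [[0.0, 4.0, 1.0, 5.0], [2.0, 6.0, 3.0, 7.0], [8.0, 12.0, 9.0, 13.0], [10.0, 14.0, 11.0, 15.0]]

reduction = nn.Linear(4 * C, 2 * C, bias=False)  # stand-in for self.reduction
print(reduction(merged).shape)  # torch.Size([1, 4, 2])
```

This is essentially a space-to-depth rearrangement (similar in spirit to F.pixel_unshuffle, although the channel order differs) followed by a linear projection.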

SwinTransformerBlock

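The core block, containing window attention (W-MSA) and shifted-window attention (SW-MSA).
I won't go through this part in detail yet, since I'm still running it myself; the core code is pasted below. (If you only want to swap out modules, you can stop here: the information above is basically enough to replace the modules you need.)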
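To see what the attn_mask computed in the block's __init__ actually encodes, here is a plain-Python mirror of that loop on a toy feature map (H = W = 4, window_size=2, shift_size=1; the real model uses window_size=7 and shift_size=3). The function names region_labels and window_masks are made up for this illustration. The labels describe the map after the cyclic shift: the last shift_size rows and columns hold content that wrapped around from the top and left, so tokens with different labels were not neighbors in the original image and get -100, which is effectively zero weight after softmax.

```python
def region_labels(H, W, window_size, shift_size):
    """Plain-Python mirror of the img_mask loop in SwinTransformerBlock.__init__."""
    labels = [[0] * W for _ in range(H)]
    slices = (slice(0, -window_size),
              slice(-window_size, -shift_size),
              slice(-shift_size, None))
    rows, cols = list(range(H)), list(range(W))
    cnt = 0
    for hs in slices:
        for ws in slices:
            for i in rows[hs]:
                for j in cols[ws]:
                    labels[i][j] = cnt
            cnt += 1
    return labels


def window_masks(labels, window_size):
    """Per window (row-major order): 0.0 where two tokens share a region, else -100.0."""
    H, W = len(labels), len(labels[0])
    masks = []
    for wi in range(0, H, window_size):
        for wj in range(0, W, window_size):
            ids = [labels[i][j] for i in range(wi, wi + window_size)
                                for j in range(wj, wj + window_size)]
            masks.append([[0.0 if a == b else -100.0 for b in ids] for a in ids])
    return masks


labels = region_labels(H=4, W=4, window_size=2, shift_size=1)
for row in labels:
    print(row)
# [0, 0, 1, 2]
# [0, 0, 1, 2]
# [3, 3, 4, 5]
# [6, 6, 7, 8]
masks = window_masks(labels, window_size=2)
print([sum(v != 0.0 for r in m for v in r) for m in masks])  # [0, 8, 8, 12]
```

Only the top-left window is untouched; the windows along the wrapped border mask out every pair of tokens that came from different regions.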

class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """
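    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(dim, window_size=(self.window_size, self.window_size),
                                    num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                    attn_drop=attn_drop, proj_drop=drop)

        # simplified: stochastic depth (timm's DropPath in the repo) is replaced by a no-op
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        # stand-in for the repo's Mlp class: Linear -> GELU -> Dropout -> Linear -> Dropout
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim), act_layer(), nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim), nn.Dropout(drop))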

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x
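The block calls window_partition and window_reverse, which the post does not show. The sketch below is a reconstruction that is consistent with how they are called above (see the shape comments in forward); to my knowledge it matches the helpers in the official repo, but treat it as an assumption rather than a verbatim copy. It also checks that the cyclic shift done with torch.roll is undone exactly by rolling back.

```python
import torch


def window_partition(x, window_size):
    """B, H, W, C -> num_windows*B, window_size, window_size, C (H, W divisible by window_size)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # bring the two window-grid axes together, then the two within-window axes
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def window_reverse(windows, window_size, H, W):
    """num_windows*B, window_size, window_size, C -> B, H, W, C (inverse of window_partition)."""
    num_windows = (H // window_size) * (W // window_size)
    B = windows.shape[0] // num_windows
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)


x = torch.randn(2, 8, 8, 3)                  # B H W C
windows = window_partition(x, 4)
print(windows.shape)                          # torch.Size([8, 4, 4, 3]): 2 images x 4 windows
print(torch.equal(windows[1], x[0, :4, 4:]))  # True: 2nd window is the top-right 4x4 block
print(torch.equal(window_reverse(windows, 4, 8, 8), x))  # True

# the cyclic shift used for SW-MSA is undone exactly by rolling back
shifted = torch.roll(x, shifts=(-2, -2), dims=(1, 2))
print(torch.equal(torch.roll(shifted, shifts=(2, 2), dims=(1, 2)), x))  # True
```

Windows are ordered image by image and, within each image, row-major over the window grid, which is the same order the attn_mask windows are built in.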

WindowAttention

Implements both shifted and non-shifted window attention, with a relative position bias.

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)  # the repo imports trunc_normal_ from timm instead
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
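The relative_position_index construction is easier to follow on a 2×2 window. This plain-Python sketch (the function name is hypothetical) repeats the same steps as the buffer code in __init__: flatten the coordinates row-major, take pairwise differences, shift them to start at 0, and fold (dh, dw) into one index with dh * (2*Ww - 1) + dw.

```python
def relative_position_index(Wh, Ww):
    """Plain-Python version of the relative_position_index buffer in WindowAttention."""
    # token k of the window sits at (k // Ww, k % Ww), the same row-major order as torch.flatten
    coords = [(h, w) for h in range(Wh) for w in range(Ww)]
    index = []
    for hi, wi in coords:
        row = []
        for hj, wj in coords:
            dh = hi - hj + Wh - 1   # row offset, shifted from [-(Wh-1), Wh-1] to [0, 2*Wh-2]
            dw = wi - wj + Ww - 1   # column offset, shifted the same way
            row.append(dh * (2 * Ww - 1) + dw)
        index.append(row)
    return index


for row in relative_position_index(2, 2):
    print(row)
# [4, 3, 1, 0]
# [5, 4, 2, 1]
# [7, 6, 4, 3]
# [8, 7, 5, 4]
```

Every index falls in [0, (2*Wh-1)*(2*Ww-1)), i.e. it selects one row of relative_position_bias_table, and the diagonal always hits the center entry (offset (0, 0)). Token pairs with the same relative offset share an index and therefore one learned bias per head, which is why the table has only (2*Wh-1)*(2*Ww-1) rows instead of (Wh*Ww)^2.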
