VIT和Swin Transformer

一 VIT模型

1 代码和模型基础

以timm包为代码基础,VIT模型以vit_base_patch16_224作为模型基础

2 模型结构

2.1 输入的图像 B ∗ 3 ∗ 224 ∗ 224 B*3*224*224 B3224224,第一步patch_embeding,这里一个patch的对应的像素大小是 16 ∗ 16 16*16 1616,也就是对输入图像作conv2d,对应的kernel_size=16,stride=16,以及升维为768,最终得到输出feature为 B ∗ 14 ∗ 14 ∗ 768 B*14*14*768 B1414768,然后转化为 B ∗ 196 ∗ 768 B*196*768 B196768,这里196个patchs其实对应了类似nlp就是196个tokens;
2.2 这里类似nlp,添加了一个起始token,这里用一个可训练的参数torch.nn.Parameter,对应的特征 B ∗ 1 ∗ 768 B*1*768 B1768,然后和上一步生成的196个tokens合并成197个tokens,对应的特征 B ∗ 197 ∗ 768 B*197*768 B197768;然后再加上一个位置编码,可训练的参数torch.nn.Parameter,对应的特征 B ∗ 197 ∗ 768 B*197*768 B197768,相加之后得到后续Block的输入
2.3 这里每个Block对应两块,一个是attention模块,一个是mlp模块;先是attention模块,就是对应的multi-head self attention,输入为 B ∗ 197 ∗ 768 B*197*768 B197768,先经过Layer_norm,在经过torch.nn.Linear升维为 768 ∗ 3 768*3 7683,这里采用heads为12个,然后reshape成 3 ∗ B ∗ 12 ∗ 197 ∗ 64 3*B*12*197*64 3B1219764,然后分别分成q,k,v,每个对应的特征 B ∗ 12 ∗ 197 ∗ 64 B*12*197*64 B1219764VIT和Swin Transformer_第1张图片
得到最终attention之后的特征在通过short_cut,加上初始输入的特征,得到最终的输出 B ∗ 197 ∗ 768 B*197*768 B197768
2.4 对应的mlp模块,这里主要是对输入先通过Layer_norm,在通过Linear进行升维768*4,然后通过gelu激活函数,加dropout,之后在通过Linear降维成768,在通过dropout,然后将该输出通过short_cut,与初始输入相加得到最终输出 B ∗ 197 ∗ 768 B*197*768 B197768
2.5 经过多个上述的Block之后,得到输出 B ∗ 197 ∗ 768 B*197*768 B197768,然后经过Layer_norm,作为最终分类,选取了第一个token作为分类的特征 B ∗ 1 ∗ 768 B*1*768 B1768,然后进入head阶段,通过Linear得到最终1000类分类

二 Swin Transformer

1 代码和模型基础

以timm包为代码基础,Swin Transformer模型以swin_base_patch4_window7_224作为模型基础;该文章解析可以参https://zhuanlan.zhihu.com/p/360513527

2 模型设计思想

2.1 对于transformer从nlp到cv中的应用,主要调整是视觉图像的scale以及高分辨率问题;针对VIT模型,token数量多,计算self-attention,对应的计算量非常大,所以该模型设计window,只计算该window内部的所有token的self attention降低计算量
VIT和Swin Transformer_第2张图片
对于其中的复杂度计算,这里可以参考卷积的flops计算,第一个计算q,k,v的复杂度,其实就是个Linear的升维操作(参照上一部VIT中计算q,k,v方式),对应的flops就是 c ∗ 1 ∗ 1 ∗ 3 c ∗ h ∗ w c*1*1*3c*h*w c113chw

2.2 基于window计算的,虽然减少了计算量,但是这样就造成了每个window的视野局限,只能看到当前window内部的token,看不到全局信息,而且每个window之间信息也不能进行交流;针对这两个问题,作者提出了2个解决方案:
a. 第一个就是类似resnet的层级结构 Hierarchical,每个stage后对 2 ∗ 2 2*2 22组的特征进行merge,同时进行升维(特征空间尺度大小 h ∗ w → h 2 ∗ h w h*w\rightarrow \frac {h}{2}*\frac {h}{w} hw2hwh,特征维度大小 C → 4 C → 2 C C\rightarrow 4C \rightarrow 2C C4C2C),这样每个window感受野就越来越大
VIT和Swin Transformer_第3张图片

b. 就是采用shift window,加强window之间的信息交流
VIT和Swin Transformer_第4张图片
对于shift之后的计算方式可以参考前面的知乎链接,具体的代码实现参考下面代码解析
VIT和Swin Transformer_第5张图片

3 模型结构

3.0 具体代码结构可以参考https://zhuanlan.zhihu.com/p/384514268
3.1 输入的图像 B ∗ 3 ∗ 224 ∗ 224 B*3*224*224 B3224224,第一步patch_embeding,这里一个patch的对应的像素大小是 4 ∗ 4 4*4 44,也就是对输入图像作conv2d,对应的kernel_size=4,stride=4,以及升维为128,最终得到输出feature为 B ∗ 56 ∗ 56 ∗ 128 B*56*56*128 B5656128,然后转化为 B ∗ 3136 ∗ 128 B*3136*128 B3136128,这里3136个patchs其实对应了类似nlp就是3136个tokens;
3.2 这里没有用到position embeding,因为这里作者采用了relative position bias,发现添加postion embeding对效果有一点损失,所以去掉了这一块,相应的对比试验
VIT和Swin Transformer_第6张图片
3.3 从上一步输入 B ∗ 3136 ∗ 128 B*3136*128 B3136128,进入Stage1中的SwinTransformerBlock,这里是2个Block交替进行,分别是W-MSA(Window based Self Attention)和SW-MSA(Shift Window based Self Attention)
VIT和Swin Transformer_第7张图片
3.3.1 第一个W-MSA和VIT中的MSA模块基本类似,只是这里面添加了个relative position bias,如下公式VIT和Swin Transformer_第8张图片
代码如下:

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)

对于relative position bias计算方式,对于一个window内部(window_size是M),大小是 M ∗ M M*M MM,先去计算每个window内每个patch的相对坐标位置;这里生成的relative_coords: 2 ∗ M 2 ∗ M 2 2*M^2*M^2 2M2M2,2是分别代表y坐标差(对应行)和x坐标差(对应列),会发现每个维度的坐标差的范围是 [ − M + 1 , M − 1 ] [-M+1, M-1] [M+1,M1],这里将坐标差转化为正数,所以对于每一个值加上 M − 1 M-1 M1,这样对应的坐标差范围是 [ 0 , 2 M − 2 ] [0, 2M-2] [0,2M2],刚好是 2 M − 1 2M-1 2M1个数,同时对y坐标乘以 2 M − 1 2M-1 2M1 ,这样在对x和y坐标差相加之后的范围是 [ 0 , ( 2 M − 2 ) ∗ ( 2 M − 1 ) + ( 2 M − 2 ) ] [0, (2M-2)*(2M-1)+(2M-2)] [0,(2M2)(2M1)+(2M2)],一共是 ( 2 M − 1 ) ∗ ( 2 M − 1 ) (2M-1)*(2M-1) (2M1)(2M1)个数,刚好对应生成的relative_position_bias_table特征大小是 ( 2 M − 1 ) ∗ ( 2 M − 1 ) (2M-1)*(2M-1) (2M1)(2M1),可以在这个特征里面找到所有相对位置relative_position_index的值;这里为什么要乘以 2 M − 1 2M-1 2M1,应该是个trick,个人猜测,第一个是如果乘以的数太小,会导致圆点坐标的patch与其他patch的坐标差有重复,第二个是乘以 2 M − 1 2M-1 2M1刚好可以使生成的特征大小是 ( 2 M − 1 ) ∗ ( 2 M − 1 ) (2M-1)*(2M-1) (2M1)(2M1),当然乘以2M应该也可以,对应生成的特征大小是 ( 2 M − 2 ) ∗ ( 2 M − 1 ) + 2 M + 1 (2M-2)*(2M-1)+2M+1 (2M2)(2M1)+2M+1,好像也能满足,具体原因还是不太明白;其中生成的特征relative_position_bias_table,是均值为0,标准差0.02的一组向量;最终的计算,就是在q和k计算attention矩阵之后,在加上根据relative_position_index的位置查找对应在relative_position_bias_table中的值,组成了最终的relative position bias,得到最终的attention矩阵;剩余的步骤与VIT中的一致,swin transformer主要增加了一项relative position bias 替换了VIT原有的position embeding
3.4 进入SW-MSA模块,这里主要是增加了一个shift操作,其余与W-MSA基本操作一样,如图:
VIT和Swin Transformer_第9张图片
在shift之后,从原来的 2 ∗ 2 2*2 22变成了 3 ∗ 3 3*3 33个window,为了批量计算,一般想法是padding成每个window同样大小,但是这样就增加了window数量,增加计算量,这里就把a,c,b三块进行移动到右下角,组成了新的 2 ∗ 2 2*2 22window,但是这样除了左上角第一个的window是完整的不需要改变,其余三个是组成的混合window,不需要使用window内部所有patch的attention,只需要以前划分的 3 ∗ 3 3*3 33对应window内的patch的attention,举例右下角的window是有A,B的下半部分,C的右半部分,及其他4部分组成,这里只需要计算A内部patch之间的attention,不需要计算A与B或者C的attention,因为A是移动过来的,计算类似图像上下边缘或者左右边缘的关系作用不大,所以加上了一个mask,去选取需要的attention,之后在将移动之后的window在移回去,达到批量快速计算的效果;这里第一步是进行shift操作,主要是通过torch.roll实现;第二步是计算相应的mask,可以参考https://zhuanlan.zhihu.com/p/360513527,如下图VIT和Swin Transformer_第10张图片
这里计算代码,如下:

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

假设输入的是 1 ∗ 56 ∗ 56 ∗ 1 1*56*56*1 156561的mask,window_size为7,shift_size为3,产生h_slices和w_slices都是按照三个(0,-7),(-7,-3),(-3,None)进行划分,一共组成了9个块,并分别按照0-8进行标记,然后再进行相减,如果是结果是0,就保留(这表示是同一块,需要计算之间的attention),其余的赋值为-100,不保留;
3.5 在每个stage之后,会先进行Patch Merge操作,对特征进行下采样然后升维的过程;假设是输入特征大小 B ∗ 56 ∗ 56 ∗ 128 B*56*56*128 B5656128,然后沿着x和y方向间隔一个取特征,分成了个4个 B ∗ 28 ∗ 28 ∗ 128 B*28*28*128 B2828128的特征,然后cat到一起,得到 B ∗ 28 ∗ 28 ∗ 512 B*28*28*512 B2828512的特征,之后reshape,加layer norm 在加一个linear 进行降维成256,其实整个Pathc Merge过程就是减小了特征的空间大小,同时增大维度
3.6 经过4个stage(每个stage对应的block数量[2,2,18,2])之后,得到特征 B ∗ 49 ∗ 1024 B*49*1024 B491024,经过layernorm,在经过一个平均池化,得到 B ∗ 1024 B*1024 B1024,然后后面是head阶段,跟着一个linear 分类成1000类,得到最终的结果

你可能感兴趣的:(VIT和Swin Transformer)