Swin Transformer之PatchMerging实现方法

PatchMerging的作用:其实就是下采样——把特征图的高和宽各减半,同时把通道数翻倍(第一种实现中由4*C线性映射到2*C)。

实现方法有两种:

第一种:

import torch
import torch.nn as nn
class PatchMerging(nn.Module):
    """Swin-style patch merging (2x spatial downsampling) on a token sequence.

    The sequence is viewed as an (H, W) grid. From each non-overlapping 2x2
    neighbourhood the four C-dim tokens are concatenated into one 4C-dim
    vector, normalised, and linearly projected to 2C. The token count
    therefore drops by a factor of 4 and the channel count doubles.

    Args:
        input_resolution: (H, W) of the token grid the input sequence encodes.
        dim: channel count C of every input token.
        norm_layer: normalisation constructor applied over the 4C features.
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        concat_dim = 4 * dim
        # 4C -> 2C projection, created before the norm as in the reference code.
        self.reduction = nn.Linear(concat_dim, 2 * dim, bias=False)
        self.norm = norm_layer(concat_dim)

    def forward(self, x):
        """Merge every 2x2 token neighbourhood.

        x: (B, H*W, C)  ->  returns (B, H/2*W/2, 2C)

        Raises AssertionError if the sequence length is not H*W or if H or W
        is odd. The print calls are demo traces whose output the article shows.
        """
        print('x',x.shape)  # demo trace: input shape (B, H*W, C)
        H, W = self.input_resolution
        batch, length, channels = x.shape
        assert length == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(batch, H, W, channels)  # (B, H, W, C)
        print('x11',grid)  # demo trace: full grid
        # Even/odd strided slices pick one member of every 2x2 cell;
        # each result is (B, H/2, W/2, C).
        top_left = grid[:, 0::2, 0::2, :]
        print('x0',top_left)  # demo trace: even rows, even cols
        bottom_left = grid[:, 1::2, 0::2, :]
        print('x1', bottom_left)  # demo trace: odd rows, even cols
        top_right = grid[:, 0::2, 1::2, :]
        bottom_right = grid[:, 1::2, 1::2, :]

        # The concatenation order fixes the layout of the 4C vector the
        # norm/reduction layers were built for.
        quads = (top_left, bottom_left, top_right, bottom_right)
        tokens = torch.cat(quads, dim=-1).view(batch, -1, 4 * channels)  # (B, H/2*W/2, 4C)
        return self.reduction(self.norm(tokens))

# Demo: a batch of 2 sequences of 16 tokens (a 4x4 grid) with 4 channels each.
a = torch.randn(2,16,4)
# Merge over a (4, 4) token grid with dim=4; the printed output shape is (2, 4, 8).
b = PatchMerging((4,4), 4)
c = b(a)
print('c',c.shape)

输出:

x torch.Size([2, 16, 4])
x11 tensor([[[[-1.0757, -0.6535,  0.1875, -1.3768],
          [-0.7268,  0.1908, -1.5902, -0.4212],
          [-0.5088, -2.6308,  0.7003,  1.7449],
          [ 0.0403,  1.0552, -0.3679, -0.1487]],

         [[-0.4162, -1.4174,  2.2844,  0.4263],
          [ 0.7606,  1.2333, -0.2414,  0.2024],
          [-0.6283, -0.7586, -1.6624,  0.9212],
          [-0.7541,  0.3502, -0.4232, -0.7529]],

         [[-1.2909,  1.6532,  0.6483,  0.3272],
          [-0.7095,  1.4838,  1.7903,  0.9732],
          [ 0.7678, -0.4203, -0.3080, -2.6463],
          [ 0.7799,  1.2861,  0.9010,  0.1704]],

         [[ 0.3832,  0.2464,  1.7505,  0.8058],
          [-0.0818, -0.6448, -0.5167,  0.5433],
          [-0.9108, -0.0747, -0.4282,  0.5872],
          [ 1.8027, -0.8964, -0.7140,  0.4678]]],


        [[[ 0.0983, -1.3372,  1.2565, -1.3958],
          [-0.0959, -0.5359,  2.3124,  0.6544],
          [-1.1128, -0.2913,  0.9412,  0.1104],
          [ 3.1558,  0.5077,  1.0304, -1.0980]],

         [[-0.4048, -0.3352,  0.1244,  1.9302],
          [-0.1532, -0.5788,  0.2044,  1.1670],
          [-1.7893,  0.5874,  0.7560, -0.5011],
          [ 1.1631, -0.6935,  1.7626,  0.4780]],

         [[ 0.0203,  0.0238,  0.0699,  0.4470],
          [ 1.8293,  0.5140, -0.8289,  0.4305],
          [ 0.5267, -0.0716,  0.1068,  0.2828],
          [ 0.0269,  0.2218,  0.2784, -0.4271]],

         [[ 2.0438, -0.2540, -0.6368, -0.5568],
          [-0.8687,  0.9175, -0.5126,  1.7711],
          [ 0.9073,  1.0147,  1.1854, -1.3229],
          [-1.0138,  1.1706, -2.1350, -1.0994]]]])
x0 tensor([[[[-1.0757, -0.6535,  0.1875, -1.3768],
          [-0.5088, -2.6308,  0.7003,  1.7449]],

         [[-1.2909,  1.6532,  0.6483,  0.3272],
          [ 0.7678, -0.4203, -0.3080, -2.6463]]],


        [[[ 0.0983, -1.3372,  1.2565, -1.3958],
          [-1.1128, -0.2913,  0.9412,  0.1104]],

         [[ 0.0203,  0.0238,  0.0699,  0.4470],
          [ 0.5267, -0.0716,  0.1068,  0.2828]]]])
x1 tensor([[[[-0.4162, -1.4174,  2.2844,  0.4263],
          [-0.6283, -0.7586, -1.6624,  0.9212]],

         [[ 0.3832,  0.2464,  1.7505,  0.8058],
          [-0.9108, -0.0747, -0.4282,  0.5872]]],


        [[[-0.4048, -0.3352,  0.1244,  1.9302],
          [-1.7893,  0.5874,  0.7560, -0.5011]],

         [[ 2.0438, -0.2540, -0.6368, -0.5568],
          [ 0.9073,  1.0147,  1.1854, -1.3229]]]])
c torch.Size([2, 4, 8])

Process finished with exit code 0

这里用到的是Python的步长切片(例如x[:, 0::2, 0::2, :]),它不是"花式索引"(fancy indexing)。该切片按行、列的奇偶取出每个2×2邻域中的4个位置,再在通道维拼接。

另一种方法实现

class PatchMerging(nn.Module):
    """Convolutional patch merging on a channels-first feature map.

    The input is normalised and then downsampled 2x with a 2x2, stride-2
    convolution. Each output pixel therefore mixes exactly one
    non-overlapping 2x2 neighbourhood.

    Args:
        dim: number of input channels.
        out_dim: number of output channels.
        norm_layer: normalisation constructor applied to the input channels.

    Shapes: (B, dim, H, W) -> (B, out_dim, H // 2, W // 2).
    """

    def __init__(self, dim, out_dim, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.dim, self.out_dim = dim, out_dim
        self.norm = norm_layer(dim)
        # kernel == stride == 2 with no padding -> non-overlapping 2x2 windows.
        self.reduction = nn.Conv2d(
            in_channels=dim,
            out_channels=out_dim,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=False,
        )

    def forward(self, x):
        """Normalise then downsample x with the stride-2 convolution."""
        normalised = self.norm(x)
        return self.reduction(normalised)

# Demo: channels-first input (B=2, C=16, H=4, W=4).
a = torch.randn(2,16,4,4)
# 16 -> 64 channels; the stride-2 conv halves H and W, giving (2, 64, 2, 2).
b = PatchMerging(16, 64)
c = b(a)
print('c',c.shape)

输出:

c torch.Size([2, 64, 2, 2])

关于分块(把图像划分成patch)可以采用两种方法:一种是直接分割;另一种是使用卷积核大小和步长都等于patch大小的卷积来分割。

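1. window_partition的作用:将形状为(B, H, W, C)的特征图分割成互不重叠的窗口,每个窗口大小为window_size×window_size,输出形状为(B*num_windows, window_size, window_size, C),其中num_windows = (H/window_size)*(W/window_size)。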

2. window_reverse的作用:它是window_partition的逆操作,把这些窗口重新拼回形状为(B, H, W, C)的特征图。

import torch
import torch.nn as nn
# Split a (B, H, W, C) feature map into non-overlapping square windows.
def window_partition(x, window_size):
    """Cut a channels-last feature map into window_size x window_size tiles.

    Args:
        x: tensor of shape (B, H, W, C). H and W are expected to be multiples
            of window_size; otherwise the reshaping view below fails.
        window_size: side length of each square window.

    Returns:
        Tensor of shape (B * (H // window_size) * (W // window_size),
        window_size, window_size, C). Windows are ordered by batch first,
        then row-major over the window grid.
    """
    batch, height, width, channels = x.shape
    rows, cols = height // window_size, width // window_size
    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    # Move the two window-grid axes next to the batch axis so each window's
    # pixels become contiguous before the final flatten.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)

# Inverse of window_partition: stitch windows back into a (B, H, W, C) map.
def window_reverse(windows, window_size, H, W):
    """Reassemble square windows into a channels-last feature map.

    Args:
        windows: tensor of shape (B * num_windows, window_size, window_size, C),
            ordered by batch, then window row, then window column (the order
            window_partition produces).
        window_size: side length of each square window.
        H: height of the original feature map; expected to be a multiple of
            window_size.
        W: width of the original feature map; expected to be a multiple of
            window_size.

    Returns:
        Tensor of shape (B, H, W, C).
    """
    rows, cols = H // window_size, W // window_size
    # Recover the batch size with exact integer arithmetic rather than a float
    # division followed by int() truncation.
    B = windows.shape[0] // (rows * cols)
    x = windows.view(B, rows, cols, window_size, window_size, -1)
    # Interleave window-grid axes with in-window axes back into full H and W.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

# Demo: round-trip a (B=2, H=8, W=8, C=16) feature map through 4x4 windows.
a = torch.randn(2,8,8,16)
c = window_partition(a, 4)
# c.shape == torch.Size([8, 4, 4, 16]): 2 images x (8/4)*(8/4) windows each
b =window_reverse(c,4,8,8)
# b.shape == torch.Size([2, 8, 8, 16]), i.e. the original layout again
print('c',c.shape)
print('b',b.shape)
c torch.Size([8, 4, 4, 16])
b torch.Size([2, 8, 8, 16])

 从零搭建Pytorch模型教程(三)搭建Transformer网络_CV技术指南(公众号)的博客-CSDN博客_pytorch搭建transformer

1. Transformer最初用于NLP,其输入是一维的token序列。到了CV领域,图像是二维的,因此需要先把图像转换成一维的token序列。通常有两种方案:一种是直接分割;另一种是使用卷积核大小和步长都等于patch大小的卷积来分割。

1)直接分割

直接分割即把图像直接切成多个patch。代码实现上可以借助einops库,把形状(B,C,H,W)调整为(B,(H/P *W/P),P*P*C),其中P是patch大小。

from einops import rearrange, repeat

from einops.layers.torch import Rearrange


# Direct (ViT-style) patch splitting. This is a fragment from a module's __init__:
# self, patch_height, patch_width, patch_dim and dim are defined there.
self.to_patch_embedding = nn.Sequential(

# (B, C, H, W) -> (B, num_patches, patch_height * patch_width * C)
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),

# Map each flattened patch to a dim-sized token; presumably
# patch_dim == patch_height * patch_width * C (TODO confirm in the enclosing module).
nn.Linear(patch_dim, dim),

)

这里需要解释的是,一个括号内的两个变量相乘表示的是该维度的长度,因此不要把"h"和"w"理解成图像的宽和高。这里实际上h = H/p1, w = W/p2,代表的是高度上有几块,宽度上有几块。h和w都不需要赋值,代码会自动根据这个表达式计算,b和c也会自动对应到输入数据的B和C。

后面的"b (h w) (p1 p2 c)"表示了图像分块后的shape: (B,(H/P *W/P),P*P*C)

这种方式在分块后还需要通过一层全连接层将分块的向量映射为tokens。

在ViT中使用的就是这种直接分块方式。
 

2)卷积分割

卷积分割比较容易理解,使用卷积核和步长都为patch大小的卷积对图像卷积一次就可以了。

# Convolutional (Swin-style) patch splitting; fragment from a module's __init__.
# kernel_size == stride == patch_size, so every output position covers one
# non-overlapping patch: (B, in_chans, H, W) -> (B, embed_dim, H // P, W // P).
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)


# Presumably from forward(): flatten(2) merges the two spatial dims and
# transpose(1, 2) moves channels last -> (B, num_patches, embed_dim).
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C


Swin Transformer使用的就是这种卷积分块方式。卷积本身已经完成了从patch到token的线性映射,所以卷积之后没有再加全连接层。

Swin-Transformer(原理 + 代码)详解_☞源仔的博客-CSDN博客_swin transformer代码
图解Swin Transformer - 知乎

SwinTransformer细节及代码实现(pytorch版本)_小小小~的博客-CSDN博客_swin transformer 代码

你可能感兴趣的:(python,深度学习,人工智能,pytorch)