Swin Transformer之PatchMerging实现方法




import torch
import torch.nn as nn
class PatchMerging(nn.Module):
    """Patch merging layer: halve the spatial resolution, double the channels.

    Every 2x2 neighbourhood of tokens is concatenated along the channel axis
    (C -> 4*C), normalised, and then linearly projected down to 2*C.

    Args:
        input_resolution: ``(H, W)`` of the token grid fed to ``forward``.
        dim: Number of input channels ``C``.
        norm_layer: Constructor of the normalisation layer; it is called with
            ``4 * dim`` because it runs on the concatenated features.
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        # nn.Module has to be initialised before any sub-module is assigned;
        # without this call `self.reduction = ...` raises AttributeError.
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge 2x2 token neighbourhoods.

        Args:
            x: Tensor of shape ``(B, H*W, C)``.

        Returns:
            Tensor of shape ``(B, H/2 * W/2, 2*C)`` with the default layers.

        Raises:
            AssertionError: If ``H*W`` does not match the token count of ``x``
                or if ``H`` or ``W`` is odd.
        """
        # The prints below trace intermediate tensors for the walkthrough.
        print('x',x.shape)  # (B, H*W, C)
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)
        print('x11',x)  # values of the (B, H, W, C) grid
        # Each slice picks one corner of every 2x2 neighbourhood.
        x0 = x[:, 0::2, 0::2, :]  # even rows, even cols -> B H/2 W/2 C
        print('x0',x0)
        x1 = x[:, 1::2, 0::2, :]  # odd rows,  even cols -> B H/2 W/2 C
        print('x1', x1)
        x2 = x[:, 0::2, 1::2, :]  # even rows, odd cols  -> B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # odd rows,  odd cols  -> B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)  # 4*C -> 2*C

        return x

# Demo: a batch of 2 inputs, each with 16 tokens (a 4x4 grid) of 4 channels.
a = torch.randn(2, 16, 4)
b = PatchMerging((4, 4), 4)
c = b(a)
# Show the merged shape so the script produces the `c ...` line of the output.
print('c', c.shape)


x torch.Size([2, 16, 4])
x11 tensor([[[[-1.0757, -0.6535,  0.1875, -1.3768],
          [-0.7268,  0.1908, -1.5902, -0.4212],
          [-0.5088, -2.6308,  0.7003,  1.7449],
          [ 0.0403,  1.0552, -0.3679, -0.1487]],

         [[-0.4162, -1.4174,  2.2844,  0.4263],
          [ 0.7606,  1.2333, -0.2414,  0.2024],
          [-0.6283, -0.7586, -1.6624,  0.9212],
          [-0.7541,  0.3502, -0.4232, -0.7529]],

         [[-1.2909,  1.6532,  0.6483,  0.3272],
          [-0.7095,  1.4838,  1.7903,  0.9732],
          [ 0.7678, -0.4203, -0.3080, -2.6463],
          [ 0.7799,  1.2861,  0.9010,  0.1704]],

         [[ 0.3832,  0.2464,  1.7505,  0.8058],
          [-0.0818, -0.6448, -0.5167,  0.5433],
          [-0.9108, -0.0747, -0.4282,  0.5872],
          [ 1.8027, -0.8964, -0.7140,  0.4678]]],

        [[[ 0.0983, -1.3372,  1.2565, -1.3958],
          [-0.0959, -0.5359,  2.3124,  0.6544],
          [-1.1128, -0.2913,  0.9412,  0.1104],
          [ 3.1558,  0.5077,  1.0304, -1.0980]],

         [[-0.4048, -0.3352,  0.1244,  1.9302],
          [-0.1532, -0.5788,  0.2044,  1.1670],
          [-1.7893,  0.5874,  0.7560, -0.5011],
          [ 1.1631, -0.6935,  1.7626,  0.4780]],

         [[ 0.0203,  0.0238,  0.0699,  0.4470],
          [ 1.8293,  0.5140, -0.8289,  0.4305],
          [ 0.5267, -0.0716,  0.1068,  0.2828],
          [ 0.0269,  0.2218,  0.2784, -0.4271]],

         [[ 2.0438, -0.2540, -0.6368, -0.5568],
          [-0.8687,  0.9175, -0.5126,  1.7711],
          [ 0.9073,  1.0147,  1.1854, -1.3229],
          [-1.0138,  1.1706, -2.1350, -1.0994]]]])
x0 tensor([[[[-1.0757, -0.6535,  0.1875, -1.3768],
          [-0.5088, -2.6308,  0.7003,  1.7449]],

         [[-1.2909,  1.6532,  0.6483,  0.3272],
          [ 0.7678, -0.4203, -0.3080, -2.6463]]],

        [[[ 0.0983, -1.3372,  1.2565, -1.3958],
          [-1.1128, -0.2913,  0.9412,  0.1104]],

         [[ 0.0203,  0.0238,  0.0699,  0.4470],
          [ 0.5267, -0.0716,  0.1068,  0.2828]]]])
x1 tensor([[[[-0.4162, -1.4174,  2.2844,  0.4263],
          [-0.6283, -0.7586, -1.6624,  0.9212]],

         [[ 0.3832,  0.2464,  1.7505,  0.8058],
          [-0.9108, -0.0747, -0.4282,  0.5872]]],

        [[[-0.4048, -0.3352,  0.1244,  1.9302],
          [-1.7893,  0.5874,  0.7560, -0.5011]],

         [[ 2.0438, -0.2540, -0.6368, -0.5568],
          [ 0.9073,  1.0147,  1.1854, -1.3229]]]])
c torch.Size([2, 4, 8])

Process finished with exit code 0



class PatchMerging(nn.Module):
    """Convolutional patch merging: normalise, then downsample by 2.

    Args:
        dim: Number of input channels.
        out_dim: Number of output channels.
        norm_layer: Constructor of the normalisation layer, called with
            ``dim`` because it runs before the convolution.
    """

    def __init__(self, dim, out_dim, norm_layer=nn.BatchNorm2d):
        # nn.Module has to be initialised before any sub-module is assigned;
        # without this call `self.norm = ...` raises AttributeError.
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.norm = norm_layer(dim)
        # kernel_size=2, stride=2, padding=0: every non-overlapping 2x2 patch
        # becomes one output position.
        self.reduction = nn.Conv2d(dim, out_dim, 2, 2, 0, bias=False)

    def forward(self, x):
        """Apply the norm and then the strided convolution to ``x``.

        With the default layers, ``(B, dim, H, W)`` maps to
        ``(B, out_dim, H/2, W/2)`` for even ``H`` and ``W``.
        """
        x = self.norm(x)
        x = self.reduction(x)
        return x

# Demo: a batch of 2 feature maps with 16 channels and 4x4 spatial size.
a = torch.randn(2, 16, 4, 4)
b = PatchMerging(16, 64)
c = b(a)
# Show the merged shape so the script produces the `c ...` line of the output.
print('c', c.shape)


c torch.Size([2, 64, 2, 2])




import torch
import torch.nn as nn
def window_partition(x, window_size):
    """Cut a feature map into non-overlapping square windows.

    Args:
        x: Tensor of shape ``(B, H, W, C)``.
        window_size: Side length of each window.

    Returns:
        Tensor of shape ``(num_windows * B, window_size, window_size, C)``.
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    # Split H into (rows, window_size) and W into (cols, window_size).
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Put the two grid axes next to each other, ahead of the in-window axes,
    # so that each window occupies one contiguous chunk.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)

def window_reverse(windows, window_size, H, W):
    """Stitch windows back into full feature maps.

    Args:
        windows: Tensor of shape ``(num_windows * B, window_size, window_size, C)``.
        window_size: Side length of each window.
        H: Height of the reassembled feature map.
        W: Width of the reassembled feature map.

    Returns:
        Tensor of shape ``(B, H, W, C)``.
    """
    # Number of windows that make up one image (kept as a float on purpose).
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    rows, cols = H // window_size, W // window_size
    grid = windows.view(batch, rows, cols, window_size, window_size, -1)
    # Interleave grid and in-window axes again: (rows, ws, cols, ws).
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch, H, W, -1)

# Demo: partition a (2, 8, 8, 16) feature map into 4x4 windows, then undo it.
a = torch.randn(2, 8, 8, 16)
c = window_partition(a, 4)
print('c', c.shape)  # torch.Size([8, 4, 4, 16])
b = window_reverse(c, 4, 8, 8)
print('b', b.shape)  # torch.Size([2, 8, 8, 16])
c torch.Size([8, 4, 4, 16])
b torch.Size([2, 8, 8, 16])




直接分割即把图像直接分成多块。在代码实现上需要使用einops这个库,完成的操作是将(B,C,H,W)的shape调整为(B,(H/P *W/P),P*P*C)。

from einops import rearrange, repeat

from einops.layers.torch import Rearrange

# Split the image into non-overlapping patches and linearly embed each one.
self.to_patch_embedding = nn.Sequential(
    # (B, C, H, W) -> (B, h*w, p1*p2*C), where h = H / p1 and w = W / p2
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
    nn.Linear(patch_dim, dim),
)


这里需要解释的是,括号内两个变量相乘表示的是该维度的长度,因此不要把"h"和"w"理解成图像的高和宽。这里实际上h = H/p1, w = W/p2,代表的是高度方向上有几块、宽度方向上有几块。h和w都不需要赋值,einops会根据这个表达式自动计算,b和c也会自动对应到输入数据的B和C。

后面的"b (h w) (p1 p2 c)"表示了图像分块后的shape: (B,(H/P *W/P),P*P*C)





# Patch embedding via convolution: kernel_size == stride == patch_size, so the
# kernel never overlaps itself and each output position covers one patch.
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

# flatten(2) collapses every dimension from index 2 onward into a single axis;
# transpose(1, 2) then swaps that axis with the channel axis.
# assumes x enters as (B, in_chans, H, W) — TODO confirm against the caller
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C

Swin Transformer中使用的正是这种卷积分块方式,并且卷积之后没有再接全连接层。

