Swin Transformer: A Detailed Walkthrough

Window Partition and Window Merging

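Example:
By default, window_size=7.
Window partition: an input of shape [4, 224, 224, 196] (B=4, H=W=224, C=196) is reshaped into [4*num_windows, 7, 7, 196], where num_windows = (224/7) * (224/7) = 32 * 32 = 1024.
Window merge: an input of shape [4*num_windows, 7, 7, 196] is reshaped back into [4, 224, 224, 196].
Both functions assume H and W are divisible by window_size.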
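The shape arithmetic above can be checked with a few lines of plain Python. This is a minimal sketch; `partition_shape` is a hypothetical helper introduced only for illustration and is not part of the Swin code.

```python
def partition_shape(B, H, W, C, window_size=7):
    """Return (num_windows, shape after window_partition) for a (B, H, W, C) input.

    Hypothetical helper for illustration only.
    """
    # window_partition reshapes H and W into whole windows, so both must divide evenly
    if H % window_size or W % window_size:
        raise ValueError("H and W must be divisible by window_size")
    # Number of windows per image: windows along H times windows along W
    num_windows = (H // window_size) * (W // window_size)
    return num_windows, (B * num_windows, window_size, window_size, C)


num_windows, windows_shape = partition_shape(4, 224, 224, 196)
print(num_windows)    # 1024
print(windows_shape)  # (4096, 7, 7, 196)
```

Merging simply runs this mapping in reverse: 4096 windows of 7x7 go back to 4 maps of 224x224.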

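def window_partition(x, window_size):
    """
    Split a feature map into non-overlapping windows.

    Args:
        x: (B, H, W, C) tensor; H and W must be divisible by window_size
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    # (B, H, W, C) -> (B, H/ws, ws, W/ws, ws, C)
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # Put the two window-grid axes side by side: (B, H/ws, W/ws, ws, ws, C),
    # then flatten batch and grid axes into a single window axis
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows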


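def window_reverse(windows, window_size, H, W):
    """
    Merge windows back into a feature map (inverse of window_partition).

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): window size
        H (int): height of the feature map
        W (int): width of the feature map
    Returns:
        x: (B, H, W, C)
    """
    # Recover B from the window count; integer arithmetic avoids float rounding
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    # (num_windows*B, ws, ws, C) -> (B, H/ws, W/ws, ws, ws, C)
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # Undo the permutation: (B, H/ws, ws, W/ws, ws, C), then merge into (B, H, W, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x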
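To see that the two functions are exact inverses, and how windows are ordered, here is a sketch that ports them to NumPy. This is an assumption made so the example runs without PyTorch: `reshape` and `transpose` stand in for `view` and `permute`, and the `_np` names are hypothetical. A non-square 14x21 map is used so that any mix-up between the H and W axes would show up.

```python
import numpy as np


def window_partition_np(x, window_size):
    # NumPy port (illustrative): (B, H, W, C) -> (B, H/ws, ws, W/ws, ws, C)
    B, H, W, C = x.shape
    x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
    # Swap axes 2 and 3 so the grid axes are adjacent, then flatten into windows
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, window_size, window_size, C)


def window_reverse_np(windows, window_size, H, W):
    # Recover B from the window count
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, -1)
    # The same axis swap is its own inverse
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(B, H, W, -1)


x = np.arange(2 * 14 * 21 * 3).reshape(2, 14, 21, 3)
windows = window_partition_np(x, 7)
print(windows.shape)  # (12, 7, 7, 3)
print(np.array_equal(window_reverse_np(windows, 7, 14, 21), x))  # True
```

Windows are laid out batch-major and then row-major over the window grid. Window index `b * nH * nW + i * nW + j` holds `x[b, i*7:(i+1)*7, j*7:(j+1)*7]`, so with nW = 3, `windows[1]` is the second window in the first row of image 0.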
