The paper proposes Swin Transformer, a new vision Transformer (ViT) that serves as a general-purpose backbone for computer vision tasks. To bridge the differences between images and NLP in data scale and resolution, it adopts a hierarchical (stage-wise) design similar to traditional convolutional networks such as ResNet, which is more flexible for targets of different scales. It also introduces shifted windows (Shifted Window) to compute self-attention within non-overlapping local windows; the shift additionally creates cross-window patch connections, so content in different windows can interact while the computational cost stays low.
Paper: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
Paper link: ICCV 2021 open access
Code: https://github.com/microsoft/Swin-Transformer
The code in this walkthrough comes from @太阳花的小绿豆, many thanks!
Instructor's code: GitHub
Code walkthrough video: https://www.bilibili.com/video/BV1yg411K7Yc
The overall architecture is shown in Fig.1. The input image is first tokenized (by the Patch Partition and Linear Embedding steps); four stages then build feature maps at different scales for downstream tasks. Each stage contains W-MSA modules, which divide the feature map into non-overlapping windows (Window) and compute multi-head self-attention only within each window. Compared with ViT's global Multi-Head Self-Attention, this greatly reduces computation, especially in shallow layers where the feature-map resolution is large. However, W-MSA blocks information flow between windows, so the paper also proposes the SW-MSA module, which enables cross-window information exchange. Between stages, the authors introduce the Patch Merging downsampling method to reduce the number of tokens.
The structure of the model implementation is shown in Fig.2; the individual modules and their code are analyzed below.
The Patch Partition and Linear Embedding parts of the architecture diagram are actually implemented together by the PatchEmbed class. At initialization it defines the patch size patch_size, the number of input channels in_c, and the embedding dimension embed_dim. The embedding is realized by a convolution whose kernel size and stride both equal patch_size and whose number of kernels is embed_dim. The input shape is (B, C, H, W) and the output shape is (B, H'*W', embed_dim). Code:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size  # patch size; also the conv kernel size and stride
        self.in_chans = in_c          # number of input image channels
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        # padding: if H and W are not integer multiples of patch_size, pad the input
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))
        # downsample by a factor of patch_size
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)  # flatten and swap dimensions
        x = self.norm(x)
        return x, H, W
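As a quick sanity check, the core of the embedding (convolution with kernel = stride = patch_size, then flatten and transpose) can be reproduced standalone; this sketch is illustrative and is not part of the original code:

```python
import torch
import torch.nn as nn

# A 224x224 RGB image with patch_size=4 yields a 56x56 grid of patches,
# i.e. 56*56 = 3136 tokens, each of dimension embed_dim = 96.
patch_size, in_c, embed_dim = 4, 3, 96
proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(2, 3, 224, 224)        # (B, C, H, W)
y = proj(x)                            # (B, 96, 56, 56)
tokens = y.flatten(2).transpose(1, 2)  # (B, 3136, 96)
print(tokens.shape)  # torch.Size([2, 3136, 96])
```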
Except for the first stage, every stage performs a Patch Merging operation before its Swin Transformer Blocks. Its purpose is downsampling, to produce features at different scales; the steps are shown in Fig.3. For an input of shape (B, H*W, C), the method samples the feature map at alternating positions to form four groups of patches whose height and width are half those of the input, concatenates the four groups, and then applies a LayerNorm and a linear projection, giving an output of shape (B, H/2*W/2, 2C). Code:
class PatchMerging(nn.Module):
    r"""
    Patch Merging Layer.
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        x = x.view(B, H, W, C)  # step 1: reshape the input back to (B, H, W, C)
        # if H or W of the feature map is odd, pad to even
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        # step 2: interleaved sampling yields four groups of patches x0, x1, x2, x3
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        # step 3: concatenate the four groups
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]
        x = self.norm(x)       # step 4: LayerNorm
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]  step 5: linear projection to change the channel count
        return x
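The interleaved-sampling steps above (steps 2-3) can be verified on a toy tensor. This standalone sketch reproduces only the sampling and concatenation, not the LayerNorm/Linear parts:

```python
import torch

# A (B, H, W, C) map becomes (B, H/2*W/2, 4C) before the Linear(4C -> 2C).
B, H, W, C = 1, 4, 4, 8
x = torch.arange(B * H * W * C, dtype=torch.float32).view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :]  # even rows, even cols
x1 = x[:, 1::2, 0::2, :]  # odd rows,  even cols
x2 = x[:, 0::2, 1::2, :]  # even rows, odd cols
x3 = x[:, 1::2, 1::2, :]  # odd rows,  odd cols
merged = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)
print(merged.shape)  # torch.Size([1, 4, 32]): 4x4 grid -> 2x2 tokens, C=8 -> 4C=32
```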
Compared with a basic ViT block, the Swin Transformer Block keeps the usual MLP, LayerNorm layers, and residual connections, but mainly introduces W-MSA and SW-MSA. The basic idea is to compute attention only among elements inside a fixed window; so that global information can still be exchanged effectively, the shifted-window attention SW-MSA is used. To handle the problems the shift introduces, a relative position bias and a shifted-window attention mask are also added.
To let elements in different Window regions interact, SW-MSA adds an offset to the windows; the shifting scheme is shown in Fig.4. After the shift, some windows span image regions that were not adjacent in the original image (such as regions A and B in Fig.4), and attention between them would be meaningless, so a mask is set up to cancel those interactions, letting only features that truly belong to the same region of the original image attend to each other. Concretely, a mask is created in which position pairs from different regions are assigned the value -100; after softmax these entries become essentially 0 and are ignored. The mask is created as follows:
def create_mask(self, x, H, W):
    # calculate attention mask for SW-MSA
    # make sure Hp and Wp are integer multiples of window_size
    Hp = int(np.ceil(H / self.window_size)) * self.window_size
    Wp = int(np.ceil(W / self.window_size)) * self.window_size
    # same channel layout as the feature map, convenient for window_partition later
    img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    # the slices separate elements that come from different regions
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    # the loop above assigns a different id to elements from each region
    mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
    # partition into windows of size window_size
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1] = [nW, Mh*Mw, Mh*Mw]
    # pairs from the same region give 0; pairs from different regions give nonzero values
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    # nonzero positions are set to -100, zero positions stay 0
    return attn_mask
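To make the mask concrete, the same construction can be run standalone on a toy 4x4 feature map with window_size=2 and shift_size=1. The window_partition helper below reproduces the standard Swin implementation so the sketch is self-contained:

```python
import torch

def window_partition(x, ws):
    # (B, H, W, C) -> (num_windows*B, ws, ws, C)
    B, H, W, C = x.shape
    x = x.view(B, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)

Hp = Wp = 4
ws, ss = 2, 1  # window_size, shift_size
img_mask = torch.zeros((1, Hp, Wp, 1))
cnt = 0
for h in (slice(0, -ws), slice(-ws, -ss), slice(-ss, None)):
    for w in (slice(0, -ws), slice(-ws, -ss), slice(-ss, None)):
        img_mask[:, h, w, :] = cnt
        cnt += 1
mask_windows = window_partition(img_mask, ws).view(-1, ws * ws)  # [4, 4]
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [4, 4, 4]
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0)
print(attn_mask.shape)  # torch.Size([4, 4, 4])
# The top-left window covers a single region, so its mask is all zeros,
# while windows that straddle region boundaries contain -100 entries:
print(attn_mask[0])
```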
Next, the SwinTransformerBlock class is built to implement W-MSA and SW-MSA; its structure is shown in Fig.5. The shift_size argument determines which kind of MSA is used. Once the MSA is constructed, the forward pass performs:
window partition → W-MSA/SW-MSA → window reverse → FFN + residual connections
which completes the whole block. The Swin Transformer Block code is as follows:
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """
    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]
        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # [nW*B, Mh, Mw, C]
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # [B, Hp, Wp, C]
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        if pad_r > 0 or pad_b > 0:
            # remove the padding added earlier
            x = x[:, :H, :W, :].contiguous()
        x = x.view(B, H * W, C)
        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
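The cyclic shift in the forward pass is simply torch.roll. A toy example (not part of the original code) shows that rolling by -shift_size and then by +shift_size restores the original map exactly, which is why the reverse shift after attention loses nothing:

```python
import torch

# Rolling by -1 along H and W moves the first row/column to the end,
# so the top-left shift region ends up at the bottom-right; ordinary window
# partitioning on the rolled map then yields the shifted windows.
x = torch.arange(16).view(1, 4, 4, 1)  # (B, H, W, C) toy map
shifted = torch.roll(x, shifts=(-1, -1), dims=(1, 2))
print(shifted[0, :, :, 0])
# tensor([[ 5,  6,  7,  4],
#         [ 9, 10, 11,  8],
#         [13, 14, 15, 12],
#         [ 1,  2,  3,  0]])
restored = torch.roll(shifted, shifts=(1, 1), dims=(1, 2))
print(torch.equal(restored, x))  # True: the reverse shift restores the map
```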
The relative position bias used inside WindowAttention is looked up through a precomputed relative_position_index, generated in the module's __init__:

self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [(2*Mh-1)*(2*Mw-1), nH]
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
# coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # for PyTorch >= 1.10
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
# [2, Mh, Mw]  stack the two generated coordinate tensors
coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]  flatten
# [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]  pairwise differences give the relative coordinates
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
self.register_buffer("relative_position_index", relative_position_index)  # stored in the model as a non-trainable buffer
# the code above generates relative_position_index
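For intuition, running the index construction standalone for a tiny 2x2 window (Mh = Mw = 2) gives the following 4x4 matrix; each entry maps a (query, key) token pair to one row of relative_position_bias_table, with (2*2-1)^2 = 9 possible relative positions. This sketch uses meshgrid with indexing="ij", which needs PyTorch >= 1.10:

```python
import torch

Mh = Mw = 2
coords = torch.stack(torch.meshgrid([torch.arange(Mh), torch.arange(Mw)], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)                      # [2, 4]
rel = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, 4, 4] pairwise diffs
rel = rel.permute(1, 2, 0).contiguous()                        # [4, 4, 2]
rel[:, :, 0] += Mh - 1          # shift row diffs to start from 0
rel[:, :, 1] += Mw - 1          # shift col diffs to start from 0
rel[:, :, 0] *= 2 * Mw - 1      # row-major flattening of the 2D offset
index = rel.sum(-1)             # [4, 4]
print(index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
```

Note that the diagonal is constant (each token relative to itself), and every value lies in [0, 8], matching the 9 rows of the bias table for a 2x2 window.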
As the ICCV 2021 best paper, this work brings the hierarchical stage structure into the ViT world: window-based self-attention greatly reduces computation, while SW-MSA enables cross-window feature interaction and thus effective use of global information. On the code side, the attention mask and the relative position bias are the most impressive parts (Microsoft showing off), though also the hardest to understand. The paper opened up the field of hierarchical ViTs and shifted-window attention.