自从Transformer[1]在NLP任务上取得突破性的进展之后,业内一直尝试着把Transformer用于在CV领域。之前的若干尝试,例如iGPT[2],ViT[3]都是将Transformer用在了图像分类领域,目前这些方法都有两个非常严峻的问题
本文提出的Swin Transformer [4]解决了这两个问题,并且在分类,检测,分割任务上都取得了SOTA的效果。Swin Transformer的最大贡献是提出了一个可以广泛应用到所有计算机视觉领域的backbone,并且大多数在CNN网络中常见的超参数在Swin Transformer中也是可以人工调整的,例如可以调整的网络块数,每一块的层数,输入图像的大小等等。该网络架构的设计非常巧妙,是一个非常精彩的将Transformer应用到图像领域的结构,值得每个AI领域的人前去学习。
在Swin Transformer之前的ViT和iGPT,它们都使用了小尺寸的图像作为输入,这种直接resize的策略无疑会损失很多信息。与它们不同的是,Swin Transformer的输入是图像的原始尺寸,例如ImageNet的224*224。另外Swin Transformer使用的是CNN中最常用的层次的网络结构,在CNN中一个特别重要的一点是随着网络层次的加深,节点的感受野也在不断扩大,这个特征在Swin Transformer中也是满足的。Swin Transformer的这种层次结构,也赋予了它可以像FPN[6],U-Net[7]等结构实现可以进行分割或者检测的任务。Swin Transformer和ViT的对比如图1。
本文将结合它的pytorch源码对这篇论文的算法细节以及代码实现展开详细介绍,并对论文中解释模糊的地方具体分析。读完此文,你将完全了解清楚Swin Transfomer的结构细节以及设计动机,现在我们开始吧。
Swin Transformer共提出了4个网络框架,它们从小到大依次是Swin-T,Swin-S, Swin-B和Swin-L,为了绘图简单,本文以最简单的Swin-T作为示例来讲解,Swin-T的结构如图2所示。Swin Transformer最核心的部分便是4个Stage中的Swin Transformer Block,它的具体在如图3所示。
class SwinTransformer(nn.Module):
def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7, downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
super().__init__()
self.stage1 = StageModule(in_channels=channels, hidden_dimension=hidden_dim, layers=layers[0], downscaling_factor=downscaling_factors[0], num_heads=heads[0], head_dim=head_dim, window_size=window_size, relative_pos_embedding=relative_pos_embedding)
self.stage2 = StageModule(in_channels=hidden_dim, hidden_dimension=hidden_dim * 2, layers=layers[1], downscaling_factor=downscaling_factors[1], num_heads=heads[1], head_dim=head_dim, window_size=window_size, relative_pos_embedding=relative_pos_embedding)
self.stage3 = StageModule(in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4, layers=layers[2], downscaling_factor=downscaling_factors[2], num_heads=heads[2], head_dim=head_dim, window_size=window_size, relative_pos_embedding=relative_pos_embedding)
self.stage4 = StageModule(in_channels=hidden_dim * 4, hidden_dimension=hidden_dim * 8, layers=layers[3], downscaling_factor=downscaling_factors[3], num_heads=heads[3], head_dim=head_dim, window_size=window_size, relative_pos_embedding=relative_pos_embedding)
self.mlp_head = nn.Sequential(
nn.LayerNorm(hidden_dim * 8),
nn.Linear(hidden_dim * 8, num_classes)
)
def forward(self, img):
x = self.stage1(img)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x) # (1, 768, 7, 7)
x = x.mean(dim=[2, 3]) # (1,768)
return self.mlp_head(x)
从源码中我们可以看出Swin Transformer的网络结构非常简单,由4个stage和一个输出头组成,非常容易扩展。Swin Transformer的4个Stage的网络框架的是一样的,每个Stage仅有几个基本的超参来调整,包括隐层节点个数,网络层数,多头自注意的头数,降采样的尺度等,这些超参的在源码的具体值如下面片段,本文也会以这组参数对网络结构进行详细讲解。
net = SwinTransformer(
hidden_dim=96,
layers=(2, 2, 6, 2),
heads=(3, 6, 12, 24),
channels=3,
num_classes=3,
head_dim=32,
window_size=7,
downscaling_factors=(4, 2, 2, 2),
relative_pos_embedding=True
)
在图2中,输入图像之后是一个Patch Partition,再之后是一个Linear Embedding层,这两个加在一起其实就是一个Patch Merging层(至少上面的源码中是这么实现的)。这一部分的源码如下:
class PatchMerging(nn.Module):
def __init__(self, in_channels, out_channels, downscaling_factor):
super().__init__()
self.downscaling_factor = downscaling_factor
self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)
def forward(self, x):
b, c, h, w = x.shape
new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
x = self.patch_merge(x) # (1, 48, 3136)
x = x.view(b, -1, new_h, new_w).permute(0, 2, 3, 1) # (1, 56, 56, 48)
x = self.linear(x) # (1, 56, 56, 96)
return x
Patch Merging的作用是对图像进行降采样,类似于CNN中Pooling层。Patch Merging是主要是通过nn.Unfold
函数实现降采样的,nn.Unfold
的功能是对图像进行滑窗,相当于卷积操作的第一步,因此它的参数包括窗口的大小和滑窗的步长。根据源码中给出的超参我们知道这一步降采样的比例是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rLLlEE8U-1638367362402)(https://www.zhihu.com/equation?tex=4)] ,因此经过nn.Unfold
之后会得到 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2n8HrNqn-1638367362404)(https://www.zhihu.com/equation?tex=%5Cfrac%7BH%7D%7B4%7D+%5Ctimes+%5Cfrac%7BW%7D%7B4%7D+%3D+%5Cfrac%7B224%7D%7B4%7D+%5Ctimes+%5Cfrac%7B224%7D%7B4%7D+%3D+3136)] 个长度为 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YnMEPlpo-1638367362405)(https://www.zhihu.com/equation?tex=4%5Ctimes4%5Ctimes3+%3D+48)] 的特征向量,其中 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xb7KFgnR-1638367362406)(https://www.zhihu.com/equation?tex=3)] 是输入到这个stage的Feature Map的通道数,第一个stage的输入是RGB图像,因此通道数为3,表示为式(1)。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-V4UVfL5g-1638367362407)(https://www.zhihu.com/equation?tex=%5Cmathbf%7Bz%7D%5E0+%3D+%5Ctext%7BMLP%7D%28%5Ctext%7BUnfold%7D%28%5Ctext%7BImage%7D%29%29+%5Ctag1)]
接着的view
和permute
是将得到的向量序列还原到 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fMF2Eqra-1638367362407)(https://www.zhihu.com/equation?tex=56%5Ctimes56)] 的二维矩阵,linear
是将长度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MAY7kJ0o-1638367362408)(https://www.zhihu.com/equation?tex=48)] 的特征向量映射到out_channels
的长度,因此stage-1的Patch Merging的输出向量维度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-i8BX9OMX-1638367362409)(https://www.zhihu.com/equation?tex=%2856%2C56%2C96%29)] ,对比源码的注释,这里省略了第一个batch为 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cRLaAuV9-1638367362409)(https://www.zhihu.com/equation?tex=1)] 的维度。
可以看出Patch Partition/Patch Merging起到的作用像是CNN中通过带有步长的滑窗来降低分辨率,再通过 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-clFNFpe2-1638367362410)(https://www.zhihu.com/equation?tex=1%5Ctimes1)] 卷积来调整通道数。不同的是在CNN中最常使用的降采样的最大池化或者平均池化往往会丢弃一些信息,例如最大池化会丢弃一个窗口内的地响应值,而Patch Merging的策略并不会丢弃其它响应,但它的缺点是带来运算量的增加。在一些需要提升模型容量的场景中,我们其实可以考虑使用Patch Merging来替代CNN中的池化。
如我们上面分析的,图2中的Patch Partition+Linaer Embedding就是一个Patch Marging,因此Swin Transformer的一个stage便可以看做由Patch Merging和Swin Transformer Block组成,源码如下。
class StageModule(nn.Module):
def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size,
relative_pos_embedding):
super().__init__()
assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'
self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension,
downscaling_factor=downscaling_factor)
self.layers = nn.ModuleList([])
for _ in range(layers // 2):
self.layers.append(nn.ModuleList([
SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
]))
def forward(self, x):
x = self.patch_partition(x)
for regular_block, shifted_block in self.layers:
x = regular_block(x)
x = shifted_block(x)
return x.permute(0, 3, 1, 2)
Swin Transformer Block是该算法的核心点,它由窗口多头自注意层(window multi-head self-attention, W-MSA)和移位窗口多头自注意层(shifted-window multi-head self-attention, SW-MSA)组成,如图3所示。由于这个原因,Swin Transformer的层数要为2的整数倍,一层提供给W-MSA,一层提供给SW-MSA。
图3:Swin Transformer Block的网络结构
从图3中我们可以看出输入到该stage的特征 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-HUnWp7Kq-1638367362411)(https://www.zhihu.com/equation?tex=%5Cmathbf%7Bz%7D%5E%7Bl-1%7D)] 先经过LN进行归一化,再经过W-MSA进行特征的学习,接着的是一个残差操作得到 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uPquOMnT-1638367362412)(https://www.zhihu.com/equation?tex=%5Chat%7B%5Cmathbf%7Bz%7D%7D%5El)] 。接着是一个LN,一个MLP以及一个残差,得到这一层的输出特征 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4GOKdNnp-1638367362413)(https://www.zhihu.com/equation?tex=%5Cmathbf%7Bz%7D%5El)] 。SW-MSA层的结构和W-MSA层类似,不同的是计算特征部分分别使用了SW-MSA和W-MSA,可以从上面的源码中看出它们除了shifted
的这个bool值不同之外,其它的值是保持完全一致的。这一部分可以表示为式(2)。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NghHIzun-1638367362413)(https://www.zhihu.com/equation?tex=%5Cbegin%7Baligned%7D+%5Chat%7B%5Cmathbf%7Bz%7D%7D%5El+%26+%3D+%5Ctext%7BW-MSA%7D%28%5Ctext%7BLN%7D%28%5Cmathbf%7Bz%7D%5E%7Bl-1%7D+%29%29+%2B+%5Cmathbf%7Bz%7D%5E%7Bl-1%7D+%5C%5C+%5Cmathbf%7Bz%7D%5E%7Bl%7D+%26+%3D+%5Ctext%7BMLP%7D%28%5Ctext%7BLN%7D%28%5Cmathbf%7B%5Chat%7Bz%7D%7D%5El%29%29+%2B+%5Cmathbf%7B%5Chat%7Bz%7D%7D%5E+l+%5C%5C+%5Chat%7B%5Cmathbf%7Bz%7D%7D%5E%7Bl%2B1%7D+%26+%3D+%5Ctext%7BSW-MSA%7D%28%5Ctext%7BLN%7D%28%5Cmathbf%7Bz%7D%5E%7Bl%7D+%29%29+%2B+%5Cmathbf%7Bz%7D%5E%7Bl%7D+%5C%5C+%5Cmathbf%7Bz%7D%5E%7Bl%2B1%7D+%26+%3D+%5Ctext%7BMLP%7D%28%5Ctext%7BLN%7D%28%5Cmathbf%7B%5Chat%7Bz%7D%7D%5E%7Bl%2B1%7D%29%29+%2B+%5Cmathbf%7B%5Chat%7Bz%7D%7D%5E%7Bl%2B1%7D+%5C%5C+%5Cend%7Baligned%7D+%5Ctag2)]
一个Swin Block的源码如下所示,和论文中图不同的是,LN层(PerNorm
函数)从Self-Attention之前移到了Self-Attention之后。
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class SwinBlock(nn.Module):
def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
super().__init__()
self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim, heads=heads, head_dim=head_dim, shifted=shifted, window_size=window_size, relative_pos_embedding=relative_pos_embedding)))
self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))
def forward(self, x):
x = self.attention_block(x)
x = self.mlp_block(x)
return x
窗口多头自注意力(Window Multi-head Self Attention,W-MSA),顾名思义,就是个在窗口的尺寸上进行Self-Attention计算,与SW-MSA不同的是,它不会进行窗口移位,它们的源码如下。我们这里先忽略shifted
为True
的情况,这一部分会放在1.6节去讲。
class WindowAttention(nn.Module):
def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
super().__init__()
inner_dim = head_dim * heads
self.heads = heads
self.scale = head_dim ** -0.5
self.window_size = window_size
self.relative_pos_embedding = relative_pos_embedding # (13, 13)
self.shifted = shifted
if self.shifted:
displacement = window_size // 2
self.cyclic_shift = CyclicShift(-displacement)
self.cyclic_back_shift = CyclicShift(displacement)
self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement, upper_lower=True, left_right=False), requires_grad=False) # (49, 49)
self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,pper_lower=False, left_right=True), requires_grad=False) # (49, 49)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
if self.relative_pos_embedding:
self.relative_indices = get_relative_distances(window_size) + window_size - 1
self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
else:
self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))
self.to_out = nn.Linear(inner_dim, dim)
def forward(self, x):
if self.shifted:
x = self.cyclic_shift(x)
b, n_h, n_w, _, h = *x.shape, self.heads # [1, 56, 56, _, 3]
qkv = self.to_qkv(x).chunk(3, dim=-1) # [(1,56,56,96), (1,56,56,96), (1,56,56,96)]
nw_h = n_h // self.window_size # 8
nw_w = n_w // self.window_size # 8
# 分成 h/M * w/M 个窗口
q, k, v = map( lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d', h=h, w_h=self.window_size, w_w=self.window_size), qkv)
# q, k, v : (1, 3, 64, 49, 32)
# 按窗口个数的self-attention
dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale # (1,3,64,49,49)
if self.relative_pos_embedding:
dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
else:
dots += self.pos_embedding
if self.shifted:
dots[:, :, -nw_w:] += self.upper_lower_mask
dots[:, :, nw_w - 1::nw_w] += self.left_right_mask
attn = dots.softmax(dim=-1) # (1,3,64,49,49)
out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)', h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w) # (1, 56, 56, 96) # 窗口合并
out = self.to_out(out)
if self.shifted:
out = self.cyclic_back_shift(out)
return out
在forward
函数中首先计算的是Transformer中介绍的 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sgFYmJ9J-1638367362414)(https://www.zhihu.com/equation?tex=Q)] , [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zX5pVQIV-1638367362415)(https://www.zhihu.com/equation?tex=K)] , [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sEgd4AA2-1638367362416)(https://www.zhihu.com/equation?tex=V)] 三个特征。所以to_qkv()
函数就是一个线性变换,这里使用了一个实现小技巧,即只使用了一个一层隐层节点数为inner_dim*3
的线性变换,然后再使用chunk(3)
操作将它们切开。因此qkv
是一个长度为3的Tensor,每个Tensor的维度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fHME9T6t-1638367362416)(https://www.zhihu.com/equation?tex=%2856%2C56%2C96%29)] 。
之后的map函数是实现W-MSA中的W最核心的代码,它是通过einops
的rearrange
实现的。einops是一个可读性非常高的实现常见矩阵操作的python包,例如矩阵转置,矩阵复制,矩阵reshape等操作。最终通过这个操作得到了3个独立的窗口的权值矩阵,它们的维度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-g9yy6v1a-1638367362417)(https://www.zhihu.com/equation?tex=%283%2C64%2C49%2C32%29)] ,这4个值的意思分别是:
Swin Transformer将计算区域控制在了以窗口为单位的策略极大减轻了网络的计算量,将复杂度降低到了图像尺寸的线性比例。传统的MSA和W-MSA的复杂度分别是:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-y6UCPTGT-1638367362422)(https://www.zhihu.com/equation?tex=%5Cbegin%7Baligned%7D+%5COmega%28%5Ctext%7BMSA%7D%29+%26+%3D+4hwC%5E2+%2B+2%28hw%29%5E2C+%5C%5C+%5COmega%28%5Ctext%7BW-MSA%7D%29+%26+%3D+4hwC%5E2+%2B+2M%5E2hwC+%5Cend%7Baligned%7D+%5Ctag3)]
(3)式的计算忽略了softmax的占用的计算量,这里以 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1mH5weSM-1638367362422)(https://www.zhihu.com/equation?tex=%5COmega%28%5Ctext%7BMSA%7D%29)] 为例,它的具体构成如下:
to_qkv()
函数,即用于生成 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ocIiTAWF-1638367362423)(https://www.zhihu.com/equation?tex=Q%2CK%2CV)] 三个特征向量:其中 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wsKXrHvd-1638367362424)(https://www.zhihu.com/equation?tex=Q%3Dx%5Ctimes+W%5EQ%2C+K%3Dx%5Ctimes+W%5EK%2C+V%3Dx%5Ctimes+W%5EV)] 。 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4lqzSIri-1638367362424)(https://www.zhihu.com/equation?tex=x)] 的维度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sGo0fUzA-1638367362425)(https://www.zhihu.com/equation?tex=%28hw%2CC%29)] , [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qevNIZtK-1638367362426)(https://www.zhihu.com/equation?tex=W)] 的维度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3b4k5mHi-1638367362426)(https://www.zhihu.com/equation?tex=%28C%2CC%29)] ,那么这三项的复杂度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zElpFzbB-1638367362427)(https://www.zhihu.com/equation?tex=3hwC%5E2)] ;to_out()
函数:它的复杂度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XTIV1fX7-1638367362437)(https://www.zhihu.com/equation?tex=hwC%5E2)] 。通过Transformer的计算公式(4),我们可以有更直观一点的理解,在Transformer一文中我们介绍过Self-Attention是通过点乘的方式得到Query矩阵和Key矩阵的相似度,即(4)式中的 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-H4DYql4S-1638367362438)(https://www.zhihu.com/equation?tex=QK%5ET)] 。然后再通过这个相似度匹配Value。因此这个相似度的计算时通过逐个元素进行点乘计算得到的。如果比较的范围是一个图像,那么计算的瓶颈就是整个图的逐像素比较,因此复杂度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VU2pIcLK-1638367362439)(https://www.zhihu.com/equation?tex=%28hw%29%5E2C)] 。而W-MSA是在窗口内的逐像素比较,因此复杂度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mGDa6dcr-1638367362440)(https://www.zhihu.com/equation?tex=M%5E2+hwC)] ,其中 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YPKk4SOC-1638367362441)(https://www.zhihu.com/equation?tex=M)] 是W-MSA的窗口的大小。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-feEZNlwo-1638367362441)(https://www.zhihu.com/equation?tex=Z+%3D+%5Ctext%7Bsoftmax%7D%5Cleft%28%5Cfrac%7BQK%5ET%7D%7B%5Csqrt%7Bd_k%7D%7D%5Cright%29+V+%5Ctag4+)]
回到代码,接着的dots
变量便是我们刚刚介绍的 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-HvVYustM-1638367362442)(https://www.zhihu.com/equation?tex=QK%5ET)] 操作。接着是加入相对位置编码,我们放到最后介绍。接着的attn
以及einsum
便是完成了式(4)的整个流程。然后再次使用rearrange
将维度再调整回 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uXy4zlKt-1638367362443)(https://www.zhihu.com/equation?tex=%2856%2C56%2C96%29)] 。最后通过to_out
将维度调整为超参设置的输出维度的值。
这里我们介绍一下W-MSA的相对位置编码,首先这个位置编码是加在乘以完归一化尺度之后的dots
变量上的,因此 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6To1Mmly-1638367362443)(https://www.zhihu.com/equation?tex=Z)] 的计算方式变为式(5)。因为W-MSA是以窗口为单位进行特征匹配的,因此相对位置编码的范围也应该是以窗口为单位,它的具体实现见下面代码。相对位置编码的具体思想参考UniLMv2[8]。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-433qvR7X-1638367362444)(https://www.zhihu.com/equation?tex=Z+%3D+%5Ctext%7Bsoftmax%7D%5Cleft%28%5Cfrac%7BQK%5ET%7D%7B%5Csqrt%7Bd_k%7D%7D+%2B+B+%5Cright%29+V+%5Ctag5)]
def get_relative_distances(window_size):
indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
distances = indices[None, :, :] - indices[:, None, :]
return distances
单独的使用W-MSA得到的网络的建模能力是非常差的,因为它将每个窗口当做一个独立区域计算而忽略了窗口之间交互的必要性,基于这个动机,Swin Transformer提出了SW-MSA。
SW-MSA的的位置是接在W-MSA层之后的,因此只要我们提供一种和W-MSA不同的窗口切分方式便可以实现跨窗口的通信。SW-MSA的实现方式如图4所示。我们上面说过,输入到Stage-1的图像尺寸是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RmVzsArK-1638367362444)(https://www.zhihu.com/equation?tex=56+%5Ctimes+56)] 的(图4.(a)),那么W-MSA的窗口切分的结果如图4.(b)所示。那么我们如何得到和W-MSA不同的切分方式呢?SW-MSA的思想很简单,将图像各循环上移和循环左移半个窗口的大小,那么图4.©的蓝色和红色区域将分别被移动到图像的下侧和右侧,如图4.(d)。那么在移位的基础上再按照W-MSA切分窗口,就会得到和W-MSA不同的窗口切分方式,如图4.(d)中红色和蓝色分别是W-MSA和SW-MSA的切分窗口的结果。这一部分可以通过pytorch的roll
函数实现,源码中是CyclicShift
函数。
class CyclicShift(nn.Module):
def __init__(self, displacement):
super().__init__()
self.displacement = displacement
def forward(self, x):
return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))
其中displacement
的值是窗口值除2。
这种窗口切分方式引入了一个新的问题,即在移位图像的最后一行和最后一列各引入了一块移位过来的区域,如图4.(d)。根据上面我们介绍的 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4O6LEnho-1638367362446)(https://www.zhihu.com/equation?tex=QK%5ET)] 用于逐像素计算相似度的自注意力机制,图像两侧的像素互相计算相似度是没有任何作用的,即只需要对比图4.(d)中的一个窗口中相同颜色的区域,我们以图4.(d)左下角的区域(1)为例来说明SW-MSA是怎么实现这个功能的。
区域(1)的计算如图5所示。首先一个 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-v9gygIr9-1638367362446)(https://www.zhihu.com/equation?tex=7%5Ctimes7)] 大小的窗口通过线性预算得到 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-g1XeeBOs-1638367362447)(https://www.zhihu.com/equation?tex=Q)] , [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZekBc5Vm-1638367362447)(https://www.zhihu.com/equation?tex=K)] , [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4b2ME3zb-1638367362448)(https://www.zhihu.com/equation?tex=V)] 三个权值,如我们介绍的,它的维度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vydGApFI-1638367362448)(https://www.zhihu.com/equation?tex=%2849%2C32%29)] 。在这个49中,前28个是按照滑窗的方式遍历区域(1)中的前48个像素得到的,后21个则是遍历区域(1)的下半部分得到的,此时他们对应的位置关系依旧保持上黄下蓝的性质。
接着便是计算 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-AoPJLydi-1638367362449)(https://www.zhihu.com/equation?tex=QK%5ET)] ,在图中相同颜色区域的相互计算后会依旧保持颜色,而黄色和蓝色区域计算后会变成绿色,而绿色的部分便是无意义的相似度。在论文中使用了upper_lower_mask
将其掩码掉,upper_lower_mask
是由 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-QxcZTvJz-1638367362450)(https://www.zhihu.com/equation?tex=0)] 和无穷大组成的二值矩阵,最后通过单位加之后得到最终的dots
变量。
upper_lower_mask
的计算方式如下。
mask = torch.zeros(window_size ** 2, window_size ** 2)
mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')
区域(2)的计算方式和区域(1)类似,不同的是区域(2)是循环左移之后的结果,如图6所示。因为(2)是左右排列的,因此它得到的[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Dh2yBjSf-1638367362451)(https://www.zhihu.com/equation?tex=Q)] , [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RaFPTex0-1638367362451)(https://www.zhihu.com/equation?tex=K)] , [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NOHIGNKC-1638367362452)(https://www.zhihu.com/equation?tex=V)]是条纹状的,即先逐行遍历,在这7行中,都会先遍历到4个黄的,然后再遍历到3个红的。两个条纹状的矩阵相乘后,得到的相似度矩阵是网络状的,其中橙色表示无效区域,因此需要网格状的掩码left_right_mask
来进行覆盖。
left_right_mask
的生成方式如下面代码。关于这两个掩码的值,你可以自己代入一些值来验证,你可以设置一下window_size
的值,然后displacement
的值设为window_size
的一半即可。
这一部分操作中,窗口移位和mask的计算是在WindowAttention
类中的第一个if shifted = True
中实现的。掩码的相加是在第二个if中实现的,最后一个if则是将图像再复原回原来的位置。
mask = torch.zeros(window_size ** 2, window_size ** 2)
mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
mask[:, -displacement:, :, :-displacement] = float('-inf')
mask[:, :-displacement, :, -displacement:] = float('-inf')
mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')
截止到这,我们从头到尾对Swin-T的stage-1进行了完成的梳理,后面3个stage除了几个超参以及图像的尺寸和stage-1不同之外,其它的结构均保持一致,这里不再赘述。
最后我们介绍一下Swin Transformer的输出层,在stage-4完成计算后,特征的维度是 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-7c4A5vRw-1638367362454)(https://www.zhihu.com/equation?tex=%28768%2C7%2C7%29)] 。Swin Transformer先通过一个Global Average Pooling得到长度为768的特征向量,再通过一个LN和一个全连接得到最终的预测结果,如式(6)。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wvfHDUqQ-1638367362454)(https://www.zhihu.com/equation?tex=%5Chat%7By%7D+%3D+%5Ctext%7BMLP%7D%28%5Ctext%7BLN%7D%28%5Ctext%7BGAP%7D%28z%5E4%29%29%29+%5Ctag6)]
Swin Transformer共提出了4个不同尺寸的模型,它们的区别在于隐层节点的长度,每个stage的层数,多头自注意力机制的头的个数,具体值见下面代码。
def swin_t(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)
def swin_s(hidden_dim=96, layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)
def swin_b(hidden_dim=128, layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)
def swin_l(hidden_dim=192, layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), **kwargs):
return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)
因为Swin Transformer是一个多阶段的网络框架,而且每一个阶段的输出也是一组Feature Map,因此可以非常方便的将其迁移到几乎所有CV任务中。作者的实验结果也表明Swin Transformer在检测和分割领域也达到了state-of-the-art的水平。
Swin Transformer是近年来为数不多的读起来让人兴奋的算法,它让人兴奋的点有三:
当然我们对Swin Transformer还是要站在一个客观的角度来评价的,虽然论文中说Swin Transformer是一个backbone,但是这个评价还为时尚早,因为
[1] Vaswani, Ashish, et al. “Attention is all you need.” arXiv preprint arXiv:1706.03762 (2017).
[2] Dosovitskiy, Alexey, et al. “An image is worth 16x16 words: Transformers for image recognition at scale.” arXiv preprint arXiv:2010.11929 (2020).
[3] Chen, Mark, et al. “Generative pretraining from pixels.” International Conference on Machine Learning. PMLR, 2020.
[4] Liu, Ze, et al. “Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.” arXiv preprint arXiv:2103.14030 (2021).
[5] Ba J L, Kiros J R, Hinton G E. Layer normalization[J]. arXiv preprint arXiv:1607.06450, 2016.
[6] T.-Y. Lin, P. Dollar, R. Girshick, K. He, B. Hariharan, and ´ S. Belongie. Feature pyramid networks for object detection. In CVPR, 2017. 2, 4, 5, 7
[7] Ronneberger O, Fischer P, Brox T. U-net: Convolutional networks for biomedical image segmentation[C]//International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015: 234-241.
[8] Bao, Hangbo, et al. “Unilmv2: Pseudo-masked language models for unified language model pre-training.” International Conference on Machine Learning. PMLR, 2020.