ViT使用纯Transformer结构来做图像分类任务,它开创了Transformer能够在CV领域有效工作的先河。ViT验证了在大规模数据集上进行预训练,然后迁移到小规模数据集上,Transformer性能要比CNN好。由于缺少CNN自带的归纳偏置(平移不变形和局部性),ViT在ImageNet数据集(中型数据集)上表现没有CNN好,Transformer需要充足的图像数据学习。
我们以ViT的base模型为例来描述ViT的流程。Transformer结构不能直接处理图像,首先需要将2D的图像分块(patch),CV中的patch可以近似看做NLP中的token,每块的大小为 P ∗ P ∗ C P*P*C P∗P∗C。假设一个大小为 224 ∗ 224 ∗ 3 224*224*3 224∗224∗3的图像,每块的大小为 16 ∗ 16 ∗ 3 16*16*3 16∗16∗3,那么此张图片将有 224 16 ∗ 224 16 = 14 ∗ 14 = 196 \frac{224}{16}*\frac{224}{16}=14*14=196 16224∗16224=14∗14=196个块。图像预处理将一个2D的 224 ∗ 224 ∗ 3 224*224*3 224∗224∗3的图像展平为1D的 196 ∗ 768 196*768 196∗768大小的向量。接下来,进行图像块嵌入(类似于NLP中的词嵌入),就是ViT论文中的 E E E, E E E的维度是 768 ∗ 768 768*768 768∗768。映射后的向量维度仍然为 196 ∗ 768 196*768 196∗768。类似于BERT中的[class] token,ViT中加入了一个可以学习的嵌入,如下图中的第0位置,它经过Transformer 编码器后的输出作为图像表示 y y y,用于分类。就这样,嵌入向量就由 196 ∗ 768 196*768 196∗768变为 197 ∗ 768 197*768 197∗768。为了保持输入图像块之间的空间位置信息,对映射后的向量添加了一个位置编码信息,如下图一中的0-9数字。位置编码采用的是1-D的可学习嵌入变量,论文中实验验证2-D的位置编码和1-D的位置编码结果近似。
Swim Transformer是特为视觉领域设计的一种分层Transformer结构。Swin 的两大特性是滑动窗口和分层表示。滑动窗口在局部不重叠的窗口中计算自注意力,并允许跨窗口连接。分层结构允许模型适配不同尺度的图片,并且计算复杂度与图像大小呈线性关系。
ViT只能够做分类,Swin Transformer借鉴了CNN的分层结构,如下图二(a),不仅能够做分类,还能够和CNN一样扩展到下游任务,比如检测,分割等。Swim Transformer不同于标准的Transformer结构,它计算不重叠窗口中的自注意力。为了解决窗口和窗口之间无连接的问题,Swin提出了移位窗口分割方法,见下图二(b),W-MSA和SW-MSA在连续的Swin Transformer blocks中交替出现,见下图二©。因此不论哪个Swim Transformer版本,都有偶数个blocks。
下图二(d)展示了Swin Transformer的tiny版本(Swin-T)。首先,它通过一个patch分割模块将输入的RGB图像分割成不重叠的patches,每个patch被看做是一个“token”,在论文中,patch size大小为 4 × 4 4 \times 4 4×4,每个patch的特征维度为 4 × 4 × 3 = 48 4 \times 4 \times 3 = 48 4×4×3=48。对于一个 H × W H \times W H×W大小的RGB图像,经过patch分割模块之后表示为 H 4 × W 4 × 48 \frac{H}{4} \times \frac{W}{4} \times 48 4H×4W×48。紧接着一个线性嵌入层将此原始值特征映射为一个任意的维度,记为 C C C。Swin Transformer block 应用到这些patch token上。线性映射加上Swin Transformer block,被称为“Stage 1”。为了得到分层表示,随着网络层数的加深,token的数量通过patch merging layers减少。第一个patch merging layer层连接每组 2 × 2 2 \times 2 2×2相邻patches的特征,然后在维度为 4 C 4C 4C的连接特征上应用线性层降维到 2 C 2C 2C。“Stage 2”,“Stage 3”和“Stage 4”由patch merging layer和Swin Transformer block组成,因此每个阶段的尺寸减少 2 2 2倍,维度增大 2 2 2倍,以至于“Stage 4”的输出特征为 H 32 × W 32 × 8 C \frac{H}{32} \times \frac{W}{32} \times 8C 32H×32W×8C 。
图二:Swin Transformer 架构
一张 H × W H \times W H×W大小的图中,里面包含 H × W H \times W H×W个像素。一个patch就是图像中的 N × N N \times N N×N个像素区域;一个window是由 M × M M \times M M×M个patches组成的。由上图所示,图像被分成 4 4 4个窗口,每个窗口包含 4 × 4 = 16 4 \times 4 =16 4×4=16个patches。假设每个patch的大小为 4 × 4 4 \times 4 4×4,则每个patch的向量维度为 4 × 4 × 3 = 48 4 \times 4 \times 3 = 48 4×4×3=48。每个patch可以看做NLP中的“token”,仿照NLP的词嵌入,将patch映射为维度为 C C C的向量。
下述代码展示了如何将图像如何进行patch嵌入。假设一张 224 × 224 × 3 224 \times 224 \times 3 224×224×3的图片,patch size大小为 4 × 4 4 \times 4 4×4,经过一个卷积层(第26行代码)之后的输出shape为 ( B , C , 56 , 56 ) (B, C, 56,56) (B,C,56,56),展平后两项,并对换后两项的位置,最后嵌入的输出为 ( B , 56 ∗ 56 , C ) (B,56*56,C) (B,56∗56,C)
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size # 图像尺寸
self.patch_size = patch_size # patch大小
self.patches_resolution = patches_resolution
# patches 数量
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans # 输入图像通道,默认3
self.embed_dim = embed_dim # 映射输出通道
# 线性映射
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # Batch, embed_dim, img_size[0] // patch_size[0], img_size[1] // patch_size[1]
if norm_layer is not None:
self.norm = norm_layer(embed_dim) # 正则
else:
self.norm = None
#
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# 假设 (H,W)=(224,224),那么(Ph,Pw)=(224/4=56, 224/4=56)
# self.proj(x)输出shape为(B, C, 56, 56)
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw=(56*56) C
if self.norm is not None:
x = self.norm(x)
return x
patch merging layers是Swim Transformer分层结构的重要组件。它连接每组 2 × 2 2 \times 2 2×2相邻patches的特征,然后在维度为 4 C 4C 4C的连接特征上应用线性层降维到 2 C 2C 2C。下图四展示了patch merging layers如何将一个 h × w × 1 h \times w \times 1 h×w×1的特征如何转换为 h 2 × w 2 × 4 \frac{h}{2} \times \frac{w}{2} \times 4 2h×2w×4。将$h \times w 特 征 特征 特征x$划分为大小为 2 × 2 2 \times 2 2×2的组,提取每组相同位置的特征得到 x 0 , x 1 , x 2 , x 3 x_{0}, x_{1}, x_{2},x_{3} x0,x1,x2,x3(下述代码第28-31行),合并 x 0 , x 1 , x 2 , x 3 x_{0}, x_{1}, x_{2},x_{3} x0,x1,x2,x3,通道数量则扩大 4 4 4倍(下述代码第32行),然后再通过线性层降维(下述代码第14和36行)。
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) # 降维 4dim ---> 2dim
self.norm = norm_layer(4 * dim) # 归一化层
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x) # 归一化
x = self.reduction(x) # 降维
return x
标准的Transformer架构是全局自注意力,它计算某个token和其他token之间的自注意力,计算复杂度和token的数量呈平方关系。视觉图像的token数量要多于语言中的单词token数量,Transformer在视觉中会耗费更多的资源,尤其对于高质量图像,计算复杂度会非常大。基于此种情况,Swim Transformer采用基于窗口的自注意替换标准的全局注意力。
将一张patches数量为 h × w h \times w h×w的图像拆分成不重叠的窗口,每个窗口包含 M × M M \times M M×M个patches。我们先回忆一下标准Transformer中的多头自注意力。假设输入为 x x x,将 x x x进行线性嵌入得到 Q , K , V Q, K, V Q,K,V三个向量, Q Q Q和 K K K两个向量相乘计算得到Attention,然后Attention与向量 V V V相乘之后再线性映射得到输出。假设patches的数量为 N N N,通道数为 C C C,那么两次线性计算复杂度为 4 N C 2 4NC^{2} 4NC2, Q , K , V Q, K, V Q,K,V的两次矩阵计算的复杂度为 2 N 2 C 2N^{2}C 2N2C。那么对于标准的多头自注意力,它的计算复杂度为 Ω ( M S A ) = 4 h w C 2 + 2 ( h w ) 2 C \Omega\left( MSA \right) = 4hwC^{2} + 2\left( hw \right)^{2}C Ω(MSA)=4hwC2+2(hw)2C;一个窗口的自注意力计算复杂度为 4 M 2 C 2 + 2 M 4 C 4M^{2}C^{2} + 2M^{4}C 4M2C2+2M4C,此张图片总共有 h M × h M \frac{h}{M} \times \frac{h}{M} Mh×Mh个窗口,那么总的基于窗口的自注意力的计算复杂度为 Ω ( W − M S A ) = 4 h w C 2 + 2 M 2 h w C \Omega\left( W-MSA \right) = 4hwC^{2} + 2M^{2} hw C Ω(W−MSA)=4hwC2+2M2hwC。
对于 224 × 224 224 \times 224 224×224大小的图片,每一个patches的大小为 4 × 4 4 \times 4 4×4,那么总共有 56 × 56 56 \times 56 56×56个patches。论文中默认 M = 7 M=7 M=7, Ω ( M S A ) \Omega\left( MSA \right) Ω(MSA)和 Ω ( W − M S A ) \Omega\left( W-MSA \right) Ω(W−MSA)计算复杂度相差在矩阵相乘的部分, 2 × 56 × 56 h w C 2 \times 56 \times 56 hwC 2×56×56hwC是 2 × 7 2 h w C 2 \times 7^{2} hwC 2×72hwC的近60倍。随着图片的尺寸越大,这个差距会越大。
论文在计算自注意力时引入了相对位置偏置(relative position bias),论文实验表明,相对位置偏置在ImageNet,CoCo和ADE20k数据集上的表现要优于不加偏置和使用绝对位置偏置。下述代码展示了带有相对位置偏置的窗口多头自注意力的前向过程。它支持窗口自注意力和移动窗口自注意力。窗口自注意力计算包含三个方面,常规多头自注意力,相对位置偏置的计算和移动窗口的掩码计算。常规多头自注意力有Transformer的基础就很好理解,难点在于相对位置偏置的计算和移动窗口的掩码计算。
A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T / d + B ) V Attention(Q,K,V) = SoftMax(QK^{T} / \sqrt{d}+ B) V Attention(Q,K,V)=SoftMax(QKT/d+B)V
# 通道注意力计算
def forward(self, x, mask=None):
"""
Attention(Q,K,V) = SoftMax(QK^{T}/sqrt(d) + Bias)V
x: 输入特征 shape: (num_windows*B, N, C)
mask: 掩码
"""
B_, N, C = x.shape # N=Wh*Ww 窗口里面的patches数量
# qkv.shape: (3, num_windows*B, self.num_heads, N, C // self.num_heads)
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # shape: (num_windows*B, self.num_heads, N, C // self.num_heads)
# self.scale对应于公式中的sqrt(d)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # QK^{T}/sqrt(d) atten.shape: (num_windows*B, self.num_heads, N, N)
# 相对位置偏置
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH nH是head的数量
#relative_position_bias.shape=(nH, Wh*Ww, Wh*Ww)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attn = attn + relative_position_bias.unsqueeze(0) # QK^{T}/sqrt(d) + Bias
# 掩码
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn) # SoftMax(QK^{T}/sqrt(d) + Bias)
attn = self.attn_drop(attn)
# x.shape=(num_windows*B, N, C)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # SoftMax(QK^{T}/sqrt(d) + Bias)V
x = self.proj(x) # 映射
x = self.proj_drop(x) # drop
return x
绝对位置编码是在进行自注意力计算之前为每个token添加一个可学习的参数,相对位置编码,是在进行自注意力计算时,在计算过程中添加一个可学习的相对位置参数。
相对位置偏置 B ∈ R M 2 × M 2 B \in \mathbb{R}^{M^{2} \times M^{2}} B∈RM2×M2,每一个轴的取值范围是 [ − M + 1 , M − 1 ] [-M+1, M-1] [−M+1,M−1]。计算自注意力时,每个token都要与其他位置上的token计算 Q K QK QK值。对于一个大小为 2 × 2 2\times2 2×2的窗口,位置1上的patch要与位置1,2,3,4的patch计算 Q K QK QK值,位置2上的patch要与位置1,2,3,4上的patch计算 Q K QK QK值,… ,那么其他位置相对于当前位置都有一个偏移量。下图5中展示了relative_coords(下述代码第8行)其他位置相当于当前位置的偏移量(按列看),为了便于后续的计算,对每个元素都加上偏移量,使其从零开始,如下述代码第9和第10行。由于(0,1)和(1,0),(-1,0)和(0,-1)它们取和后的总偏移量结果一样,因为对某一列坐标进行乘法变换,如下述代码第11行,最后再取和得到总的偏移量relative_position_index。至此,相对位置的下标取值范围为 [ 0 , 8 ] [0,8] [0,8],可由一个 ( 2 M − 1 ) ∗ ( 2 M − 1 ) (2M-1)*(2M-1) (2M−1)∗(2M−1)大小的矩阵表示,参数化这个更小尺寸的偏置矩阵 B ^ ∈ R ( 2 M − 1 ) × ( 2 M − 1 ) \hat{B} \in \mathbb{R}^{\left( 2M-1\right) \times \left( 2M-1\right)} B^∈R(2M−1)×(2M−1),那么 B B B的值就可以从 B ^ \hat{B} B^中提取。
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
# 以下用Wh * Ww 替代 self.window_size[0] * self.window_size[1]
coords_h = torch.arange(self.window_size[0]) # [0,1,...,Wh-1]
coords_w = torch.arange(self.window_size[1]) # [0,1,...,Ww-1]
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
# 将relative_position_index注册为一个不参与网络学习的变量。
self.register_buffer("relative_position_index", relative_position_index)
# 使用截断正态分布中提取的值填充输入张量。
trunc_normal_(self.relative_position_bias_table, std=.02)
# forward函数中相对未知的偏置
'''self.relative_position_index是计算出不可学习的量 第17行
self.relative_position_index.shape=(Wh*Ww, Wh*Ww) 第15行
self.relative_position_bias_table.shape=(2*Wh-1 * 2*Ww-1, nH) 第2行
self.relative_position_index矩阵中的所有值都是从self.relative_position_bias_table中取的
'''
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH nH是head的数量
Swim 计算每个窗口中的自注意,窗口与窗口之间无计算,失去了Transformer处理全局信息的特征。为此,Swim Transformer提出了移位窗口分割方法。假设有 8 × 8 8 \times 8 8×8 patches的图片,其中一个窗口包含 4 × 4 4 \times 4 4×4个patches,那么将有 2 × 2 2 \times 2 2×2个窗口。现在将左上角的窗口向左下移位 2 2 2个patches,四个窗口将被重新划分为 9 9 9个大小不一的窗口,如图六所示,只有标号为4的窗口和原窗口大小一致。
最直接的想法是对小窗口进行padding,并在计算的时候屏蔽掉填充的值。但是,自注意力计算将由四个被扩展到九个,计算多了2.25倍。为了不增加计算量,论文中提出了循环移位(cyclic-shift)算法,如图六所示,将编号3,6的窗口移位到编号5,8的窗口下面,将编号0,1的窗口移位到编号6,7的窗口左面,将编号为0的窗口,从左上角移位到右下角。这样就可以重新拼凑出 2 × 2 2 \times 2 2×2 (4,(7,1),(3,5),(0,2,6,8))个窗口。拼凑出的窗口在原图中属于不同的位置,不相连,以标号为0,2,6,8窗口组成的大窗口为例,这四个小窗口分别位于原图的四个顶点,关联性极低,因此,在计算窗口注意力时,需要掩码机制,只能计算相同子窗口的自注意力,不同窗口的自注意力结果要为0。标号为0,2,6,8窗口,在计算窗口自注意力时,窗口0中的每一个patch分别需要和窗口0,2,6,8中的每一patch进行自注意力计算,那么窗口0中的patch与窗口0中的patch的自注意力是有用的,但是窗口0中的patch与窗口2,6,8中的patch的自注意力需要设为0。我们回忆一下Attention的计算公式, A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T / d + B ) V Attention(Q,K,V) = SoftMax(QK^{T} / \sqrt{d}+ B) V Attention(Q,K,V)=SoftMax(QKT/d+B)V,自注意力计算最后需要Softmax函数。在不同窗口的自注意力值上添加 − 100 -100 −100(下图代码第20行,mask赋值-100;第27行,将mask添加到自注意力值上,然后再进行softmax计算),在softmax计算过程中, − 100 -100 −100会无限趋近于0,从而达到归0的效果。
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
# 给每个窗口上索引
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# 拆分窗口
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # nW, window_size * window_size
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size * window_size, window_size * window_size
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) #
else:
attn_mask = None
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) # 加mask
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
下图展示了循环移位之后的窗口组成和mask值的分布情况。
Swin Transformer有四种形式,分别命名为Swin-T,Swin-S,Swin-B和Swin-L。以Swin-T为基础模型版本,Swin-T,Swin-S,和Swin-L分别是基础模型的 0.25 × 0.25\times 0.25×, 0.5 × 0.5\times 0.5×和 2 × 2\times 2×倍。这四种模型的架构如图8所示。
ViT模型将Transformer结构应用到视觉领域,但是仍然还受限于图片的尺寸大小。Swin引入移动窗口和分层结构,使得自注意力在视觉领域的计算复杂度能与图片大小成线性关系。Swin吸取了CNN和Transformer的优点,在ImageNet-1k的数据集上也能取得SOTA效果,相比于ViT模型,降低了数据的需求量。