SegFormer
论文链接: https://arxiv.org/abs/2105.15203
代码链接: https://github.com/NVlabs/SegFormer
Demo链接: https://www.bilibili.com/video/BV1MV41147Ko/
SETR 采用了 ViT 作为backbone,然后与CNN的decoders结合来扩大特征图的分辨率。但是ViT有以下缺点:
对于第一个缺点, (pyramid vision transformer)PVT 基于 ViT提出了金字塔结构。但是,当PVT结合Swin Transformer和Twins的时候,这些方法主要考虑的是Transformer encoder,忽略了对decoder的进一步改进和提升。
基于效率,精度和鲁棒性,作者提出了以下两个方向:
所提出的encoder避免插入位置信息,对于训练和测试不同分辨率的输入,不会影响性能。除此之外,hierarchical部分可以使得encoder产生高分辨率精细的特征图和低分辨率粗糙的特征图。其次,提出的lightweight ALL-MLP decoder是Transformer产生的特征,其中较低层的注意力往往停留在局部,而最高层的注意力是非局部的。通过融合不同层的信息,MLP encoder结合了局部和全局的注意力。
先来看一下整体的网络结构:
可以看到,网络主要分两个部分:
输入一张图片 H × W × 3 H \times W \times 3 H×W×3,首先会分成size为 4 × 4 4 \times 4 4×4的patches。使用小一点的patches size有益于dense prediction task。 然后使用这些patches作为输入,输入到hierarchical Transformer encoder里面来获得多层级的特征图,分别为原图的 { 1 / 4 , 1 / 8 , 1 / 16 , 1 / 32 } \{1/4, 1/8, 1/16, 1/32\} {1/4,1/8,1/16,1/32}大小。然后把这些多层级的特征图输入到ALL-MLP decoder中,预测大小为 H 4 × W 4 × N c l s \frac{H}{4} \times \frac{W}{4} \times N_{cls} 4H×4W×Ncls的segmentation mask,其中 N c l s N_{cls} Ncls是类别的数量。
Encoder由图示模块所堆叠起来: 包括Efficient Self-Attention,Mix-FFN和Overlap Patch Embedding三个模块。
Overlap Patch Embedding
给定一个大小为 H × W × 3 H \times W \times 3 H×W×3的输入,使用patch merging来获得多层级的feature map F i F_i Fi,对应的分辨率大小为: H 2 i − 1 × W 2 i + 1 × C i , i ∈ { 1 , 2 , 3 , 4 } \frac{H}{2^{i-1} } \times \frac{W}{2^{i+1} } \times C_i, \quad i \in \{1,2,3,4\} 2i−1H×2i+1W×Ci,i∈{1,2,3,4},其中通道数 C i + 1 C_{i+1} Ci+1大于 C i C_i Ci。 Overlapped Patch merging操作会把 N × N × 3 N \times N \times 3 N×N×3的输入变为 1 × 1 × C 1\times 1\times C 1×1×C的vector。不太了解的话,我们看以下源代码一些片段:
class MixVisionTransformer(nn.Module):
def __init__(self, ...):
super().__init__()
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4,
in_chans=in_chans, embed_dim=embed_dims[0])
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2,
in_chans=embed_dims[0], embed_dim=embed_dims[1])
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2,
in_chans=embed_dims[1], embed_dim=embed_dims[2])
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2,
in_chans=embed_dims[2], embed_dim=embed_dims[3])
class OverlapPatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
可以看到,Overlap patch merging操作主要是2D卷积,通过改变patch_size和stride把特征图进行缩放,形成了特征层级结构。ViT中使用的patch merging过程结合的是非重叠的patches,因此会导致无法保留这些patches之间的局部联系。而这里所用的是overlapping patch merging。通过patch_size和stride的不同达到效果
Efficient Self-Attention
encoder的主要计算瓶颈在self-attention层。在原始的multi-head self-attention过程中,每一个heads Q , K , V Q, K, V Q,K,V都是相同的维度 N × C N \times C N×C,其中 N = H × W N = H \times W N=H×W是序列的长度。所以self-attention定义为:
Attention(Q, K, V) = Softmax ( Q K T d h e a d ) V \text{Attention(Q, K, V)} = \text{Softmax}(\frac{QK^T}{\sqrt{d_{head}}})V Attention(Q, K, V)=Softmax(dheadQKT)V
这个过程的复杂度是 O ( N 2 ) O(N^2) O(N2),对于分辨率大的图像是难以承担的。所以作者提出了 sequence reduction process。也就是PVT中提出来的spatial reduction操作。
K ^ = R e s h a p e ( N R , C ⋅ R ) ( K ) K = L i n e a r ( C ⋅ R , C ) ( K ^ ) \hat{K} = Reshape(\frac{N}{R}, C\cdot R)(K) \\ K = Linear(C \cdot R, C) (\hat{K}) K^=Reshape(RN,C⋅R)(K)K=Linear(C⋅R,C)(K^)
简单来说,输入维度为 N × C N\times C N×C的K 首先会reshape成维度为 N R × ( C ⋅ R ) \frac{N}{R} \times (C\cdot R) RN×(C⋅R)维度。然后通过Linear变换,將 ( C ⋅ R ) (C \cdot R) (C⋅R)的维度变为 C C C。这样输出的 K K K维度为 N R × C \frac{N}{R}\times C RN×C,计算复杂度为 O ( N 2 R ) O(\frac{N^2}{R}) O(RN2)。作者把reduction ratio设置为 [ 64 , 16 , 4 , 1 ] [64,16,4,1] [64,16,4,1],分别对应stage-1到stage-4。
因此,在本文中 Q , K , V Q, K, V Q,K,V的维度分别为 N × C , N R × C , N R × C N \times C, \frac{N}{R} \times C, \frac{N}{R}\times C N×C,RN×C,RN×C。 因此 softmax的维度为: ( Q K T ) V (Q K^T) V (QKT)V,也就是 ( N × C ) × ( C × N R ) × ( N R × C ) = ( N × N R ) × ( N R × C ) = N × C (N \times C) \times (C \times \frac{N}{R}) \times ( \frac{N}{R} \times C) = (N \times \frac{N}{R}) \times (\frac{N}{R} \times C) = N \times C (N×C)×(C×RN)×(RN×C)=(N×RN)×(RN×C)=N×C。
下面来看一下代码节选:
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
Mix-FFN
ViT使用了position embedding来引入定位信息。但是,position embedding的大小是固定的,因此,当测试的时候输入分辨率与训练的时候的分辨率不一致的话,positional code就需要被插值,导致精度的下降。所以CPV在positional embedding上使用了3X3卷积,来实现数据驱动的一个positional encoding。本文提出的是:positional encoding对于语义分割是没有必要的。 作者引入的是Mix-FFN模块,通过在Feed-forward network(FFN)上直接使用3X3卷积,减弱了zero-padding会丢失一些定位信息的影响。
x o u t = MLP ( GELU ( Conv 3 × 3 ( MLP ( x i n ) ) ) ) + x i n x_{out} = \text{MLP}(\text{GELU}(\text{Conv}_{3\times 3}(\text{MLP}(x_{in}))))+x_{in} xout=MLP(GELU(Conv3×3(MLP(xin))))+xin
其中, x i n x_{in} xin是self-attention模块中的输出特征。作者实验表明,3X3卷积足够提供位置信息给Transformer。特别地,作者使用了depth-wise convolutions来减少参数量来提升效率。
代码节选:
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
Decoder是由MLP组成的。能够使用如此简单的decoder是因为hierarchical Transformer encoder比传统的CNN encoders有较大的有效感知域。
可以看到,Decoder分为四个步骤:
multi-level features F i F_i Fi 会输入到MLP层,使得通道的维度一样。
F ^ i = Linear ( C i , C ) ( F i ) , ∀ i \hat{F}_i = \text{Linear}(C_i, C)(F_i), \forall i F^i=Linear(Ci,C)(Fi),∀i
features map会进行上采样到原图的 1 4 \frac{1}{4} 41,并且拼接在一起。
F ^ i = Upsample ( W 4 × W 4 ) ( F ^ i ) , ∀ i \hat{F}_i = \text{Upsample}(\frac{W}{4} \times \frac{W}{4})(\hat{F}_i), \forall i F^i=Upsample(4W×4W)(F^i),∀i
再次使用MLP层把拼接的特征融合。
F = Linear ( 4 C , C ) ( Concat ( F ^ i ) ) , ∀ i F = \text{Linear}(4C, C)(\text{Concat}(\hat{F}_i)), \forall i F=Linear(4C,C)(Concat(F^i)),∀i
另外一个MLP层將融合的特征进行segmentation mask的预测。输入维度为 H 4 × W 4 × N c l s \frac{H}{4} \times \frac{W}{4} \times N_cls 4H×4W×Ncls
M = Linear ( C , N c l s ) ( F ) M = \text{Linear}(C, N_{cls})(F) M=Linear(C,Ncls)(F)
通过下面源码可知,第一步的Linear是FC层,Upsample是bilinear插值法。第三步的特征融合和第四步预测都是卷积操作。
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
class MLP(nn.Module):
"""
Linear Embedding
"""
def __init__(self, input_dim=2048, embed_dim=768):
super().__init__()
self.proj = nn.Linear(input_dim, embed_dim)
def forward(self, x):
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class SegFormerHead(BaseDecodeHead):
"""
SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
"""
def __init__(self, feature_strides, **kwargs):
super(SegFormerHead, self).__init__(input_transform='multiple_select', **kwargs)
assert len(feature_strides) == len(self.in_channels)
assert min(feature_strides) == feature_strides[0]
self.feature_strides = feature_strides
c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
decoder_params = kwargs['decoder_params']
embedding_dim = decoder_params['embed_dim']
self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
self.linear_fuse = ConvModule(
in_channels=embedding_dim*4,
out_channels=embedding_dim,
kernel_size=1,
norm_cfg=dict(type='SyncBN', requires_grad=True)
)
self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
def forward(self, inputs):
x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32
c1, c2, c3, c4 = x
############## MLP decoder on C1-C4 ###########
n, _, h, w = c4.shape
_c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
# F.interpolate(input, size, scale_factor, mode, align_corners)
_c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)
_c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
_c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)
_c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
_c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)
_c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
_c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
x = self.dropout(_c)
x = self.linear_pred(x)
return x
不同大小的网络结构如下图所示。
所对应的参数分别为:
提出了一个结合了positional-encoding-free,hierarchical Transformer encoder和lightweight ALL-MLP decoder语义分割网络。改进了ViT和应用了PVT。