这篇文章属于是将Transfomer用于语义分割的早期尝试。编码器和解码器都用的是ViT的transfomer块。创新点主要位于解码器部分,作者构思了两种解码器让从transfomer编码器里面出来的特征图经过解码器得到最终分割图。
论文地址:论文PDF地址
代码地址:github代码地址
图像分割在单独的图像patch级别会出现模糊不清的情况,分割需要上下文信息才能有较好的效果。本文介绍的Segmenter可以在网络中进行全局上下文建模。
Segmenter考虑了了两种解码器,一种是较为简单的线性解码器,另一种是mask decoder,mask decoder取得了较好的效果。
作者指出Segmenter为了用Transfomer做分割,编码器使用了ViT,没有大的改变,而解码器考虑了两种,一种是较为简单的线性的逐点映射的利用class score解码器,第二种是用class数量的可学习token与编码器输出token联合处理生成class mask。
。。。
编码器部分不用过多描述,Segmenter的编码器使用了L层,每一层包括一个多头self-attention,两个point-wise MLP,在块前有LN,即layer norm.,块后有残差结构。
MSA=multi-headed self-attention
LN=layer norm
为了从编码器生成的特征图(NxD)中得到分割结果,线性解码器用了一个point-wise 的linear layer,这样特征图的shape从NxD到NxK
N是patch数量,K是class num。然后再reshape成H/P ×W/P ×K再用双线性插值法将特征图的宽和高上采样到H,W,于是shape变为HxWxK,再用一次softmax得到最终分割图。这也算是全MLP解码层,但毕竟是早期论文,与后面的Segfomer相比,Segfomer也是轻量化的全MLP解码器,但是Segfomer将编码器内多层特征图插值到相同规格再拼接,能更好的结合上下文信息。而本文的线性解码器略显粗糙。不过本文采用的是设计更为复杂的mask decoder.
还是要考虑从编码器NxD的特征图中获得分割结果的问题,mask transfomer放弃了全MLP的解码器,采用transfomer的结构。解码器实际上就是M层编码器组成。现在引入了K个可学习的class embedding,shape是KxD,随机初始化。然后按接下来几个步骤来获得分割图。假设从编码器里出来的特征图为Z,class embeeding是C,忽略batchsize可以看作:
1.拼接Z与C。这样shape的变化是(NxC),(KxC) ——> ((N+K), C)
2.拼接图进入transfomer解码器模块,进行M层的transfomer模块运算。这样shape没有改变。
3.拆开特征图,这样shape的变化是((N+K), C)——> (NxC),(KxC)
4.Mask = ZCT.这样shape的变化是(NxC)*(CxK) ——> (NxK)
5.reshape(Mask). 这样shape的变化是(NxK) ——> (H/P ×W/P ×K)
6.双线性插值上采样。这样shape的变化是(H/P ×W/P ×K)——> (H ×W ×K)
7.softmax
mmsegmentation中的实现与论文有简单的不同,送入解码器的特征图shape为(B,C,H,W)
def forward(self, inputs):
x = self._transform_inputs(inputs)
b, c, h, w = x.shape
x = x.permute(0, 2, 3, 1).contiguous().view(b, -1, c)#(B,C,H,W)->(B,H,W,C)->(B,HW,C)
x = self.dec_proj(x)
cls_emb = self.cls_emb.expand(x.size(0), -1, -1)#clas embeeding(B, K, C)
x = torch.cat((x, cls_emb), 1)#拼接Z与C,(B, HW+K, C)
for layer in self.layers:
x = layer(x)#将混合特征图送入L层编码器
x = self.decoder_norm(x)
patches = self.patch_proj(x[:, :-self.num_classes])#patch部分特征图进行一次linear,(B, HW, C)
cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])#class token部分特征图进行一次linear,(B, K, C)
patches = F.normalize(patches, dim=2, p=2)#归一化
cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)#归一化
masks = patches @ cls_seg_feat.transpose(1, 2)#点积,shape(B, HW, C)*(B,C,K) ——> (B, HW, K)
masks = self.mask_norm(masks)
masks = masks.permute(0, 2, 1).contiguous().view(b, -1, h, w)#(B, HW, K)->(B,K,HW)->(B,K,H,W)
return masks
在实验中,值得关注的是使用线性解码器和mask解码器的效果区别。从上图可以看出在模型大小,patch size相同的情况下,mask解码器的效果比线性解码器的效果更好。