【SOD论文阅读笔记】Visual Saliency Transformer

【SOD论文阅读笔记】Visual Saliency Transformer

    • 一、摘要
      • Motivation:
      • Method:
      • Experimental results
    • 二、Introduction
      • 当前最先进的方法以CNN结构为主
      • CNN结构的弊端
      • 引出Transformer
      • 本文中
      • contributions
    • 三、Visual Saliency Transformer
      • Transformer Encoder(T2t_vit_t_14)
      • Transformer Convertor
      • Multi-task Transformer Decoder

一、摘要

【SOD论文阅读笔记】Visual Saliency Transformer_第1张图片

Motivation:

现有的SOTA显著性检测方法在很大程度上依赖于基于CNN的网络。可替代地,我们从卷积free的sequence-to-sequence的角度重新考虑此任务,并通过建模长期依赖关系来预测显著性,而这不能通过卷积来实现。

这篇论文的出发点就是利用transformer来创新,并且这篇文章是纯transformer(convolution-free),所以摘要中从transformer和CNN的最大的不同出发来写motivation——即transformer对比CNN来说,是sequence-to-sequence结构的,且更有利于对长期依赖关系建模。

Method:

提出基于纯变压器的模型,即视觉显著性变压器 (VST),用于RGB和RGBD的显著性检测。

  • 以图像补丁为输入,并利用transformer在图像补丁之间传播全局上下文
  • 与视觉变压器 (ViT) 中使用的常规结构不同,我们利用多级token融合,并在变压器框架下提出了一种新的token上采样方法,以获得高分辨率的检测结果。
  • 我们还开发了基于token的多任务解码器,通过引入与任务相关的token和新颖的补丁-任务-注意力机制,同时执行显着性和边界检测。

先解释一下图像补丁。由于transormer是从NLP任务传到CV领域的,在NLP的机器翻译任务中,输入的是一个个单词,所以,把transformer移植到图像任务时,为了与其输入结构保持一致,会把图像切割成不重叠的补丁序列(可以想像一下把一张图片切割成九宫格/N宫格,每一个宫格就是一个补丁)。

再解释一下token。刚刚的图像补丁就可以被称之为一个token,它属于patch token。patch token输入到transformer中后,经过处理得到的feature也可以成为token。此外,transformer中还有一种class token,它本质上就是一个可训练的向量,通常在分类任务中直接通过这个Class token来判断类别。

这篇论文里有一个任务相关的token(task-related tokens),其实相当于tokens的一个头部,代表这个tokens是用于做什么任务的。这是因为,这篇论文提出的是多任务模型,输出的是 显著映射 和 边缘映射,本意是借助边缘的监督提升其显著映射的准确性。

Experimental results

实验结果表明,我们的模型在RGB和RGBD SOD基准数据集上都优于现有方法。

二、Introduction

当前最先进的方法以CNN结构为主

它们通常采用编码器-解码器架构,其中编码器将输入图像编码为多级特征,解码器将提取的特征集成以预测最终的显着性图。

  • RGB-SOD,旨在检测吸引人们眼睛的物体,并可以帮助许多视觉任务。
    • 各种注意力模型,多尺度特征集成方法和多任务学习框架
  • RBGD-SOD,则多了来自深度数据的额外空间结构信息。
    • 各种模态融合方法,如特征融合,知识蒸馏,动态卷积,注意力模型 ,图神经网络 。

CNN结构的弊端

所有方法在学习全局远程依赖方面受到限制

长期以来,全局上下文 和全局对比度 对于显著性检测至关重要。然而,由于cnn在局部滑动窗口中提取特征的内在限制,以前的方法很难利用关键的全局线索。

尽管一些方法利用全连接层,全局池化和非本地模块来合并全局上下文,但它们仅在某些层中这样做,并且基于CNN的体系结构保持不变。

引出Transformer

最近,提出了Transformer用于机器翻译的单词序列之间的全局远程依赖关系。

Transformer的核心思想是自注意机制,它利用query-key的相关性来关联序列中的不同位置。Transformer在编码器和解码器中多次堆叠自注意层,因此可以对每一层中的长距离依赖进行建模。因此,将变压器引入SOD是很自然的,一路利用模型中的全局线索。

本文中

我们从新的序列到序列的角度重新考虑SOD,并基于纯变压器开发了一种新颖的RGB和rgb-d SOD统一模型,称为视觉显着性变压器。

最近提出的ViT模型 [12,74],将每个图像划分为补丁,并在补丁序列上采用变压器模型。然后,变压器在图像补丁之间传播长距离依赖,而无需使用卷积。

然而,将ViT应用于SOD并不简单,存在两大问题:

  • 1.关于密集预测: 如何基于纯变压器执行密集预测任务仍然是一个悬而未决的问题。
    - 我们通过引入与任务相关的token来设计基于token的变压器解码器从而学习决策嵌入。然后,我们提出了一种新颖的补丁-任务-注意力机制来生成密集预测结果,这为在密集预测任务中使用transformer提供了新的范例。
    - 在以前的SOD模型的激励下,利用边界检测来提高SOD性能,我们构建了一个多任务解码器,通过引入显著性token和边界token,同时进行显著性和边界检测。该策略通过简单地学习与任务相关的token来简化多任务预测工作流程,从而大大降低了计算成本,同时获得了更好的结果。
  • 2.关于高分辨率:ViT通常将图像标记为非常粗糙的大小。如何使ViT适应SOD的高分辨率预测需求还不清楚。
    - 受tokens-to-tokens (T2T) 转换 [74] 的启发,该转换减少了tokens的长度,我们提出了一种新的反向T2T转换,通过将每个tokens扩展为多个子tokens来向上采样tokens。然后,我们逐步对补丁tokens进行采样,并将其与低级token融合,以获得最终的全分辨率显着性图。此外,我们还使用交叉模态transformer来深入探索rgb-d SOD的多模态信息之间的相互作用。

在RGB和RGBD数据上,以有可比性的数量的参数和计算成本,优于现有的最先进的SOD方法

contributions

  • 以序列to序列建模的新视角,设计了一种基于纯变压器架构的RGB和rgb-d SOD的新型统一模型。
  • 设计了一种多任务变压器解码器,通过引入任务相关的token和补丁-任务-注意力来联合进行显著性和边界检测
  • 一种新的基于transformer的token上采样方法
  • state-of-the-art结果

三、Visual Saliency Transformer

【SOD论文阅读笔记】Visual Saliency Transformer_第2张图片

我们为RGB和RGBD SOD提出的VST模型的整体架构。它首先使用编码器从输入的图像补丁序列中生成多级tokens。然后,采用转换器将补丁tokens转换为解码器空间,并对rgb-d数据进行跨模态信息融合。最后,解码器通过我们提出的与任务相关的token以及补丁-任务-注意机制同时预测显着图和边界图。还提出了一种RT2T转换,以逐步上采样补丁tokens。虚线表示rgb-d SOD的专用成分。

  • 主要组件包括3部分:基于T2T-ViT的变压器encoder (T2t_vit_t_14),用于将补丁tokens从编码器空间转换到解码器空间的变压器转换器 (Transformer),以及多任务变压器decoder (token_Transformer, Decoder)。
class ImageDepthNet(nn.Module):
    def __init__(self, args):
        super(ImageDepthNet, self).__init__()
        # VST Encoder
        self.rgb_backbone = T2t_vit_t_14(pretrained=True, args=args)
        # VST Convertor
        self.transformer = Transformer(embed_dim=384, depth=4, num_heads=6, mlp_ratio=3.)
        # VST Decoder
        self.token_trans = token_Transformer(embed_dim=384, depth=4, num_heads=6, mlp_ratio=3.)
        self.decoder = Decoder(embed_dim=384, token_dim=64, depth=2, img_size=args.img_size)

    def forward(self, image_Input):
        B, _, _, _ = image_Input.shape
        # image_Input [B, 3, 224, 224]
        # VST Encoder
        rgb_fea_1_16, rgb_fea_1_8, rgb_fea_1_4 = self.rgb_backbone(image_Input)
        # rgb_fea_1_16 [B, 14*14, 384]
        # rgb_fea_1_8 [B, 28*28, 384]
        # rgb_fea_1_4 [B, 56*56, 384]
        # VST Convertor
        rgb_fea_1_16 = self.transformer(rgb_fea_1_16)
        # rgb_fea_1_16 [B, 14*14, 384]
        # VST Decoder
        saliency_fea_1_16, fea_1_16, saliency_tokens, contour_fea_1_16, contour_tokens = self.token_trans(rgb_fea_1_16)
        # saliency_fea_1_16 [B, 14*14, 384]
        # fea_1_16 [B, 1 + 14*14 + 1, 384]
        # saliency_tokens [B, 1, 384]
        # contour_fea_1_16 [B, 14*14, 384]
        # contour_tokens [B, 1, 384]
        outputs = self.decoder(saliency_fea_1_16, fea_1_16, saliency_tokens, contour_fea_1_16, contour_tokens, rgb_fea_1_8, rgb_fea_1_4)
		# [mask_1_16, mask_1_8, mask_1_4, mask_1_1],[contour_1_16, contour_1_8, contour_1_4, contour_1_1]
		# mask_1_16/contour_1_16 [B, 1, 14, 14]
		# mask_1_1/contour_1_1 [B, 1, 224, 224]
        return outputs

Transformer Encoder(T2t_vit_t_14)

以下是Transformer Encoder的整体框架

class T2T_ViT(nn.Module):
    def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm):
        super().__init__()
     
        self.tokens_to_token = T2T_module(img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.tokens_to_token.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
    
    def forward(self, x):
        B = x.shape[0]
        x, x_1_8, x_1_4 = self.tokens_to_token(x)
		#[B,196,384],[B, 28×28, 384],[B, 56×56, 384]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        #[1,1,384]->[B,1,384]
        x = torch.cat((cls_tokens, x), dim=1)
        #cat([B,1,384],[B,196,384])->[B,197,384]
        x = x + self.pos_embed
        #[B,197,384]+[1,197,384]->[B,197,384]

        # T2T-ViT backbone
        for blk in self.blocks:
            x = blk(x)
		#[B,197,384]
        x = self.norm(x)
  		#[B,197,384]
        return x[:, 1:, :], x_1_8, x_1_4

可以看出,Transformer Encoder由一个T2T模块和一些后处理步骤构成。
输入:(B,3,224,224)
输出:由于我们做的是像素级分类而不是对象级分类,所以输出了多级特征:fea_1_16 [B, 14×14, 384],fea_1_8 [B, 28×28, 384],fea_1_4 [B, 56×56, 384]。

T2T模块:待会儿详细介绍。
后处理步骤:

  1. 首先,x被concat了一个1维的全零分类tokens,由于其被初始化为0,所以没什么好介绍的。

x = torch.cat((cls_tokens, x), dim=1)

  1. 其次,x被add了一个shape与其shape相同的正弦位置tokens

self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False)
x = x + self.pos_embed

这里对self.pos_embed的初始化是有讲究的,用到的是《Attention is all you need》中提出的正弦位置,参数就是要生成的shape的参数。

3.最后,重复经过depth个Blocks。这里depth设置为14。
每个Block:

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return 

该过程就是不断Attention、MLP的迭代过程,且输出与输入的shape保持一致[B, 197, 384]。
Attention就是普通多头attention(Linear[通道数扩大三倍]、分为qkv、softmax(q*k)*v,最后再Linear[不改变通道数])

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
    def forward(self, x):
        B, N, C = x.shape
        #[B,197,384]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # self.qkv(x):[B,197*3,384]
        #.reshape(B, N, 3, self.num_heads, C // self.num_heads): [B,197,3,6,64]
        #.permute(2, 0, 3, 1, 4): [3,B,6,197,64]
        q, k, v = qkv[0], qkv[1], qkv[2]
		#[B,6,197,64]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # k.transpose(-2, -1): [B,6,64,197]
        # q @ k.transpose(-2, -1):[B,6,197,197]
        attn = attn.softmax(dim=-1)
        # [B,6,197,197]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # attn @ v : [B,6,197,197] * [B,6,197,64] -> [B,6,197,64]
        # .transpose(1, 2) : [B,197,6,64]
        # .reshape(B, N, C) : [B,197,384]
        x = self.proj(x)
        #[B,197,384]
        return x

MLP就是(Linear[通道数扩大3倍]、Gelu激活、Linear[通道数恢复])

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
    def forward(self, x):
        x = self.fc1(x)
        #[B,197,384*3]
        x = self.act(x)
        x = self.fc2(x)
        #[B,197,384]
        return x

  • Tokens to Token模块

给定一系列长度为l的补丁tokens T’,T2T-ViT会连续堆叠T2T模块。
T2T模块是由重构步骤(a re-structurization step: 多头自注意力+多层感知机)和软拆分步骤(a soft split step:unfold)组成的,对T’中的局部结构信息进行建模,并获得新的token序列。
T2T变换可以多次迭代进行。在每次的迭代中,重构步骤首先将以前的token嵌入转换为新的嵌入,并且还在所有token内集成了远程依赖关系。然后,软拆分操作将每个k × k邻居中的token聚合成一个新token,该token准备用于下一层。
此外,当设置s

个人觉得这里的tokens-to-tokens模块更应该叫做features-to-features模块,因为这个模块的输入是二维的features,进入模块后会先软分割(unfold)变形为1维的向量,即tokens串,然后self-attention,最后再reshape成二维的特征图。

  • 重构步骤 a re-structurization step
    tokens T’会首先使用一个transformer层,获得一个新的tokens T ∈ R l × c T∈R^{l×c} TRl×c
    transformer层: MSA 多头自注意力+MLP多层感知机
    之后,T会被reshape为2维图像I∈Rh×w×c,从而恢复空间结构
    【SOD论文阅读笔记】Visual Saliency Transformer_第3张图片
  • 软拆分步骤 a soft split step
    与ViT不同,T2T-ViT中采用的重叠补丁拆分在相邻补丁中引入了局部对应关系,从而带来了空间先验。
    I ∈ R h × w × c I∈R^{h×w×c} IRh×w×c首先会给边界补上p个0,之后被拆分为重叠区域为s的k×k个补丁块。
    然后图像补丁块会被展开成一系列tokens T o ∈ R l o × c k 2 T_{o}∈ R^{l_{o}×ck^{2}} ToRlo×ck2
    在这里插入图片描述
  • 具体设置:我们按照 [74] 首先将输入图像软分割成补丁,然后两次迭代T2T模块。在三个软拆分步骤中,补丁大小设置为k = [7,3,3],重叠映射设置为s = [3,1,1],填充大小设置为p = [2,1,1]。因此,我们可以获得多级tokensT1 ∈ Rl1 × c,T2 ∈ Rl2 × c和T3 ∈ Rl3 × c。给定输入图像的宽度和高度分别为H和W,则l1 = H /4 × W/ 4,l2 = H/8 × W/8,l3 = H/16 × W/16。我们遵循 [74] 设置c = 64,并使用t3上的线性投影层将其嵌入尺寸从c转换为d = 384。
class T2T_module(nn.Module):
    """
    Tokens-to-Token encoding module
    """
    def __init__(self, img_size=224, tokens_type='performer', in_chans=3, embed_dim=768, token_dim=64):
        super().__init__()
        if tokens_type == 'transformer':
            print('adopt transformer encoder for tokens-to-token')
            self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
            self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

            self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)
            self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)
            self.project = nn.Linear(token_dim * 3 * 3, embed_dim)

        elif tokens_type == 'performer':
            ……
        elif tokens_type == 'convolution':  # just for comparison with conolution, not our model
            ……
        self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2))  # there are 3 sfot split, stride are 4,2,2 seperately
    def forward(self, x):
    	#Input[B,3,224,224]
        # step0: soft split
        x = self.soft_split0(x).transpose(1, 2)
        # (224 + 2*2 - 7) / 4 + 1 =  56
		# self.soft_split0(x):[B,147=7*7*3,56*56]
        # .transpose(1, 2):[B, 56*56, 147=7*7*3]
        # iteration1: restricturization/reconstruction
        x_1_4 = self.attention1(x)
        # [B, 56*56, 64]
        B, new_HW, C = x_1_4.shape
        x = x_1_4.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
        #[B,64,56,56]
        
        # iteration1: soft split
        x = self.soft_split1(x).transpose(1, 2)
		# self.soft_split1(x) : [B,576=3*3*64,28*28]
		#.transpose(1, 2) : [B,28*28,576]
        # iteration2: restricturization/reconstruction
        x_1_8 = self.attention2(x)
        #[B,28*28,64]
        B, new_HW, C = x_1_8.shape
        x = x_1_8.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
        #[B,64,28,28]
        
        # iteration2: soft split
        x = self.soft_split2(x).transpose(1, 2)
        #[B,14*14,576=3*3*64]
        # final tokens
        x = self.project(x)
		#[B,196,384]
		
        return x, x_1_8, x_1_4

其中,Token_transformer的结构与上述Block特别相似,都是由Attention和MLP组成。
区别:
Attention中:to_qkv时不再设置为原有通道数的3倍,而是64的3倍,从而实现了通道数的改变;
不再设置multi-head;最终残差相加的不是原来的输入(因为通道数变了,没办法直接加),而是v。

MLP中:两次Linear的通道数没有改变。

Encoder with T2T-ViT Backbone

  • 最后的token序列T3与编码2D位置信息的正弦位置嵌入 [61] add起来。然后,使用 L ε L^{\varepsilon} Lε transformer层对T3之间的长期依赖进行建模,以提取强大的补丁token嵌入 T ε ∈ R l 3 × d T^{\varepsilon} ∈ R^{l_{3} × d} TεRl3×d
  • SOD:应用1个transformer encoder将RGB图像编码为补丁tokens T r ε ∈ R l 3 × d T_{r}^{\varepsilon} ∈ R^{l_{3} × d} TrεRl3×d
  • RSOD:应用双流transformer encoder,将深度图像以同样的方式编码为补丁tokens T d ε ∈ R l 3 × d T_{d}^{\varepsilon} ∈ R^{l_{3} × d} TdεRl3×d

Transformer Convertor

我们在变压器编码器和解码器之间插入一个转换器模块,以将编码器补丁tokensTE ∗ 从编码器空间转换到解码器空间,从而获得转换后的补丁tokensTc ∈ Rl3 × d。从输出的shape可以看出,这里特征的形状并没有改变。

  • RGB-D Convertor
  • RGB Convertor
  • transforner层:多个Block+layernorm
    Block:
    x = x+self-attention(layernorm(x))
    x = x+mlp(layernorm(x))

与刚刚Transformer Encoder中最后进行的多个Block的完全一样,这次设置了4个Block,加上刚刚的14个,相当于让fea_1_16经历了18次Attention+MLP。

class TransformerEncoder(nn.Module):
    def __init__(self, depth, num_heads, embed_dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        self.blocks = nn.ModuleList([
                 Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,norm_layer=norm_layer)
                 for i in range(depth)])
        self.rgb_norm = norm_layer(embed_dim)
    def forward(self, rgb_fea):
        for block in self.blocks:
            rgb_fea = block(rgb_fea)
        rgb_fea = self.rgb_norm(rgb_fea)
        return 

这里不改变输入的shape,输入该模块的是fea_1_16[B,14×14,384],输出的仍然是fea_1_16[B,14×14,384]。

Multi-task Transformer Decoder

这个模块在论文中的思路已经在思维导图中写了,以下按照代码思路串一遍。
刚刚在总框架代码中写了,decoder实际上包含了两部分:token_Transformer, Decoder。

def __init__(self, args):
		……
		# VST Decoder
        self.token_trans = token_Transformer(embed_dim=384, depth=4, num_heads=6, mlp_ratio=3.)
        self.decoder = Decoder(embed_dim=384, token_dim=64, depth=2, img_size=args.img_size)
def forward(self, image_Input):
		……
		# VST Decoder
        saliency_fea_1_16, fea_1_16, saliency_tokens, contour_fea_1_16, contour_tokens = self.token_trans(rgb_fea_1_16)
        # saliency_fea_1_16 [B, 14*14, 384]
        # fea_1_16 [B, 1 + 14*14 + 1, 384]
        # saliency_tokens [B, 1, 384]
        # contour_fea_1_16 [B, 14*14, 384]
        # contour_tokens [B, 1, 384]
        outputs = self.decoder(saliency_fea_1_16, fea_1_16, saliency_tokens, contour_fea_1_16, contour_tokens, rgb_fea_1_8, rgb_fea_1_4)
		# [mask_1_16, mask_1_8, mask_1_4, mask_1_1],[contour_1_16, contour_1_8, contour_1_4, contour_1_1]
		# mask_1_16/contour_1_16 [B, 1, 14, 14]
		# mask_1_1/contour_1_1 [B, 1, 224, 224]
        return outputs

首先看 token_Transformer,
这部分主要引入了与任务相关的token以及patch-任务-注意力。
它的输入是fea_1_16[B,14×14,384],输出了5部分:

  • 代表saliency任务的任务tokens: saliency_tokens [B, 1, 384]
  • 代表saliency任务的特征tokens:saliency_fea_1_16 [B, 14×14, 384]
  • 代表边缘任务的任务tokens: contour_tokens [B, 1, 384]
  • 代表边缘任务的特征tokens:contour_fea_1_16 [B, 14*14, 384]
  • 总的特征tokens:fea_1_16 [B, 1 + 14×14 + 1, 384]
class token_Transformer(nn.Module):
    def __init__(self, embed_dim=384, depth=14, num_heads=6, mlp_ratio=3.):
        super(token_Transformer, self).__init__()

        self.norm = nn.LayerNorm(embed_dim)
        self.mlp_s = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )
        self.saliency_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.contour_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.encoderlayer = token_TransformerEncoder(embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio)
        self.saliency_token_pre = saliency_token_inference(dim=embed_dim, num_heads=1)
        self.contour_token_pre = contour_token_inference(dim=embed_dim, num_heads=1)

    def forward(self, rgb_fea):
        B, _, _ = rgb_fea.shape
        fea_1_16 = self.mlp_s(self.norm(rgb_fea))   # [B, 14*14, 384]
        saliency_tokens = self.saliency_token.expand(B, -1, -1) # [B, 1, 384]
        fea_1_16 = torch.cat((saliency_tokens, fea_1_16), dim=1) # [B, 1+14*14, 384]

        contour_tokens = self.contour_token.expand(B, -1, -1) # [B, 1, 384]
        fea_1_16 = torch.cat((fea_1_16, contour_tokens), dim=1) #[B, 1 + 14*14 + 1, 384]

        fea_1_16 = self.encoderlayer(fea_1_16)
        # fea_1_16 [B, 1 + 14*14 + 1, 384]
        
        saliency_tokens = fea_1_16[:, 0, :].unsqueeze(1) # [B, 1, 384]
        contour_tokens = fea_1_16[:, -1, :].unsqueeze(1) # [B, 1, 384]

        saliency_fea_1_16 = self.saliency_token_pre(fea_1_16) # [B, 14*14, 384]
        contour_fea_1_16 = self.contour_token_pre(fea_1_16) # [B, 14*14, 384]
        return saliency_fea_1_16, fea_1_16, saliency_tokens, contour_fea_1_16, 

这里,token_TransformerEncoder与刚刚的Transformer Convertor设置完全一样,仍然是4个多头注意力Attention+MLP组成的blocks。

重点介绍一下saliency_token_inference和contour_token_inference。
它们俩的输入都是总的特征tokens fea_1_16 [B, 1 + 14×14 + 1, 384],输出的是分别代表saliency和边缘的特征tokens: [B, 14×14, 384] 。

saliency_token_inference:

class saliency_token_inference(nn.Module):
    def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()

        self.norm = nn.LayerNorm(dim)
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sigmoid = nn.Sigmoid()

    def forward(self, fea):
        B, N, C = fea.shape
        x = self.norm(fea)
        T_s, F_s = x[:, 0, :].unsqueeze(1), x[:, 1:-1, :]
        # T_s [B, 1, 384]  F_s [B, 14*14, 384]

        q = self.q(F_s).reshape(B, N-2, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        #[B,196,1,384]->[B,1,196,384]
        k = self.k(T_s).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        #[B,1,1,384]->[B,1,1,384]
        v = self.v(T_s).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
		#[B,1,1,384]->[B,1,1,384]
        attn = (q @ k.transpose(-2, -1)) * self.scale
		#[B,1,196,384]*[B,1,384,1]->[B,1,196,1]
        attn = self.sigmoid(attn)
        attn = self.attn_drop(attn)

        infer_fea = (attn @ v).transpose(1, 2).reshape(B, N-2, C)
        #[B,1,196,1]*[B,1,1,384]->[B,1,196,384]->[B,196,1,384]->[B,196,384]
        infer_fea = self.proj(infer_fea)
        #[B,196,384]
        infer_fea = self.proj_drop(infer_fea)

        infer_fea = infer_fea + fea[:, 1:-1, :]
        #[B,196,384]
        return infer_fea

contour_token_inference与saliency_token_inference一样,只不过在取任务token时,取的是-1位。

接下来介绍Decoder。
这部分主要是反T2T的上采样,以及多级特征融合。
输入的是7部分,包括刚刚第一部分的decoder的输出,以及 encoder输出的fea_1_8和 fea_1_4。

  • saliency_fea_1_16 [B, 14*14, 384]
  • fea_1_16 [B, 1 + 14*14 + 1, 384]
  • saliency_tokens [B, 1, 384]
  • contour_fea_1_16 [B, 14*14, 384]
  • contour_tokens [B, 1, 384]
  • fea_1_8 [B, 28*28, 64]
  • fea_1_4 [B, 56*56, 64]
class Decoder(nn.Module):
    def __init__(self, embed_dim=384, token_dim=64, depth=2, img_size=224):

        super(Decoder, self).__init__()

        self.norm = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, token_dim),
        )

        self.norm_c = nn.LayerNorm(embed_dim)
        self.mlp_c = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, token_dim),
        )
        self.img_size = img_size
        # token upsampling and multi-level token fusion
        self.decoder1 = decoder_module(dim=embed_dim, token_dim=token_dim, img_size=img_size, ratio=8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), fuse=True)
        self.decoder2 = decoder_module(dim=embed_dim, token_dim=token_dim, img_size=img_size, ratio=4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), fuse=True)
        self.decoder3 = decoder_module(dim=embed_dim, token_dim=token_dim, img_size=img_size, ratio=1, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2), fuse=False)
        self.decoder3_c = decoder_module(dim=embed_dim, token_dim=token_dim, img_size=img_size, ratio=1, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2), fuse=False)

        # token based multi-task predictions
        self.token_pre_1_8 = token_trans(in_dim=token_dim, embed_dim=embed_dim, depth=depth, num_heads=1)
        self.token_pre_1_4 = token_trans(in_dim=token_dim, embed_dim=embed_dim, depth=depth, num_heads=1)

        # predict saliency maps
        self.pre_1_16 = nn.Linear(token_dim, 1)
        self.pre_1_8 = nn.Linear(token_dim, 1)
        self.pre_1_4 = nn.Linear(token_dim, 1)
        self.pre_1_1 = nn.Linear(token_dim, 1)
        # predict contour maps
        self.pre_1_16_c = nn.Linear(token_dim, 1)
        self.pre_1_8_c = nn.Linear(token_dim, 1)
        self.pre_1_4_c = nn.Linear(token_dim, 1)
        self.pre_1_1_c = nn.Linear(token_dim, 1)

        for m in self.modules():
            classname = m.__class__.__name__
            if classname.find('Conv') != -1:
                nn.init.xavier_uniform_(m.weight),
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif classname.find('Linear') != -1:
                nn.init.xavier_uniform_(m.weight),
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif classname.find('BatchNorm') != -1:
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, saliency_fea_1_16, token_fea_1_16, saliency_tokens, contour_fea_1_16, contour_tokens, rgb_fea_1_8, rgb_fea_1_4):
        # saliency_fea_1_16 [B, 14*14, 384]
        # contour_fea_1_16 [B, 14*14, 384]
        # token_fea_1_16  [B, 1 + 14*14 + 1, 384] (contain saliency token and contour token)
        # saliency_tokens [B, 1, 384]
        # contour_tokens [B, 1, 384]
        # rgb_fea_1_8 [B, 28*28, 64]
        # rgb_fea_1_4 [B, 56*56, 64]

        B, _, _, = token_fea_1_16.size()

        saliency_fea_1_16 = self.mlp(self.norm(saliency_fea_1_16))
        # saliency_fea_1_16 [B, 14*14, 64]
        mask_1_16 = self.pre_1_16(saliency_fea_1_16)
        # mask_1_16 [B,14*14,1]
        mask_1_16 = mask_1_16.transpose(1, 2).reshape(B, 1, self.img_size // 16, self.img_size // 16)
        # mask_1_16 [B,1,14,14]

        contour_fea_1_16 = self.mlp_c(self.norm_c(contour_fea_1_16))
        # contour_fea_1_16 [B, 14*14, 64]
        contour_1_16 = self.pre_1_16_c(contour_fea_1_16)
        contour_1_16 = contour_1_16.transpose(1, 2).reshape(B, 1, self.img_size // 16, self.img_size // 16)

        # 1/16 -> 1/8
        # reverse T2T and fuse low-level feature
        fea_1_8 = self.decoder1(token_fea_1_16[:, 1:-1, :], rgb_fea_1_8)

        # token prediction
        saliency_fea_1_8, contour_fea_1_8, token_fea_1_8, saliency_tokens, contour_tokens = self.token_pre_1_8(fea_1_8, saliency_tokens, contour_tokens)

        # predict saliency maps and contour maps
        mask_1_8 = self.pre_1_8(saliency_fea_1_8)
        mask_1_8 = mask_1_8.transpose(1, 2).reshape(B, 1, self.img_size // 8, self.img_size // 8)

        contour_1_8 = self.pre_1_8_c(contour_fea_1_8)
        contour_1_8 = contour_1_8.transpose(1, 2).reshape(B, 1, self.img_size // 8, self.img_size // 8)

        # 1/8 -> 1/4
        fea_1_4 = self.decoder2(token_fea_1_8[:, 1:-1, :], rgb_fea_1_4)

        # token prediction
        saliency_fea_1_4, contour_fea_1_4, token_fea_1_4, saliency_tokens, contour_tokens = self.token_pre_1_4(fea_1_4, saliency_tokens, contour_tokens)

        # predict saliency maps and contour maps
        mask_1_4 = self.pre_1_4(saliency_fea_1_4)
        mask_1_4 = mask_1_4.transpose(1, 2).reshape(B, 1, self.img_size // 4, self.img_size // 4)

        contour_1_4 = self.pre_1_4_c(contour_fea_1_4)
        contour_1_4 = contour_1_4.transpose(1, 2).reshape(B, 1, self.img_size // 4, self.img_size // 4)

        # 1/4 -> 1
        saliency_fea_1_1 = self.decoder3(saliency_fea_1_4)
        contour_fea_1_1 = self.decoder3_c(contour_fea_1_4)

        mask_1_1 = self.pre_1_1(saliency_fea_1_1)
        mask_1_1 = mask_1_1.transpose(1, 2).reshape(B, 1, self.img_size // 1, self.img_size // 1)

        contour_1_1 = self.pre_1_1_c(contour_fea_1_1)
        contour_1_1 = contour_1_1.transpose(1, 2).reshape(B, 1, self.img_size // 1, self.img_size // 1)

        return [mask_1_16, mask_1_8, mask_1_4, mask_1_1], [contour_1_16, contour_1_8, contour_1_4, contour_1_1]

核心在于decoder_module模块。
我们用出现的第一个decoder_module模块为例,它的参数设置为:

self.decoder1 = decoder_module(dim=384, token_dim=64, img_size=224, ratio=8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), fuse=True)

输入的是token_fea_1_16的中间段(即去掉两头的任务token,留下feature token)[B,196,384]
以及rgb_fea_1_8 [B, 28*28, 64]

fea_1_8 = self.decoder1(token_fea_1_16[:, 1:-1, :], rgb_fea_1_8)

下面是decoder_module

class decoder_module(nn.Module):
    def __init__(self, dim=384, token_dim=64, img_size=224, ratio=8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), fuse=True):
        super(decoder_module, self).__init__()
        self.project = nn.Linear(token_dim, token_dim * kernel_size[0] * kernel_size[1])
        self.upsample = nn.Fold(output_size=(img_size // ratio,  img_size // ratio), kernel_size=kernel_size, stride=stride, padding=padding)
        self.fuse = fuse
        if self.fuse:
            self.concatFuse = nn.Sequential(
                nn.Linear(token_dim*2, token_dim),
                nn.GELU(),
                nn.Linear(token_dim, token_dim),
            )
            self.att = Token_performer(dim=token_dim, in_dim=token_dim, kernel_ratio=0.5)
            # project input feature to 64 dim
            self.norm = nn.LayerNorm(dim)
            self.mlp = nn.Sequential(
                nn.Linear(dim, token_dim),
                nn.GELU(),
                nn.Linear(token_dim, token_dim),
            )
    def forward(self, dec_fea, enc_fea=None):
        if self.fuse:
            # from 384 to 64
            #[B,14*14,384]->[B,14*14,64]
            dec_fea = self.mlp(self.norm(dec_fea))
        # [1] token upsampling by the proposed reverse T2T module
        #由于要扩大feature的面积,所以要改变通道
        #[B,14*14,64]->[B,14*14,64*3*3]
        dec_fea = self.project(dec_fea)
        
        #[B,14*14,64*3*3]->[B,64*3*3,14*14]->[B,64,28,28]
        dec_fea = self.upsample(dec_fea.transpose(1, 2))
        B, C, _, _ = dec_fea.shape
        #[B,64,28*28]->[B,28*28,64]
        dec_fea = dec_fea.view(B, C, -1).transpose(1, 2)
        
        # [B, HW, C]
        if self.fuse:
            # [2] fuse encoder fea and decoder fea
            #concat([B,28*28,64],[B, 28*28, 64])->[B, 28*28, 128]->[B, 28*28, 64]
            dec_fea = self.concatFuse(torch.cat([dec_fea, enc_fea], dim=2))
            #[B, 28*28, 64]
            dec_fea = self.att(dec_fea)
        return 

这里的att不同于以上的Token_transformer。
以上的Token_transformer是由多头Attention+MLP(通道数先扩大再缩小)组成。
而此处的att由token_performer和MLP(通道数保持不变)组成。

class Token_performer(nn.Module):
    def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1):
        super().__init__()
        self.emb = in_dim * head_cnt # we use 1, so it is no need here
        self.kqv = nn.Linear(dim, 3 * self.emb)
        self.dp = nn.Dropout(dp1)
        self.proj = nn.Linear(self.emb, self.emb)
        self.head_cnt = head_cnt
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(self.emb)
        self.epsilon = 1e-8  # for stable in division

        self.mlp = nn.Sequential(
            nn.Linear(self.emb, 1 * self.emb),
            nn.GELU(),
            nn.Linear(1 * self.emb, self.emb),
            nn.Dropout(dp2),
        )

        self.m = int(self.emb * kernel_ratio)
        self.w = torch.randn(self.m, self.emb)
        self.w = nn.Parameter(nn.init.orthogonal_(self.w) * math.sqrt(self.m), requires_grad=False)

    def prm_exp(self, x):
        # part of the function is borrow from https://github.com/lucidrains/performer-pytorch 
        # and Simo Ryu (https://github.com/cloneofsimo)
        # ==== positive random features for gaussian kernels ====
        # x = (B, T, hs)
        # w = (m, hs)
        # return : x : B, T, m
        # SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)]
        # therefore return exp(w^Tx - |x|/2)/sqrt(m)
        xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2
        wtx = torch.einsum('bti,mi->btm', x.float(), self.w)

        return torch.exp(wtx - xd) / math.sqrt(self.m)

    def single_attn(self, x):
        k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)
        kp, qp = self.prm_exp(k), self.prm_exp(q)  # (B, T, m), (B, T, m)
        D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)  # (B, T, m) * (B, m) -> (B, T, 1)
        kptv = torch.einsum('bin,bim->bnm', v.float(), kp)  # (B, emb, m)
        y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon)  # (B, T, emb)/Diag
        # skip connection
        # y = v + self.dp(self.proj(y))  # same as token_transformer in T2T layer, use v as skip connection
        y = self.dp(self.proj(y))
        return y

    def forward(self, x):
        x = x + self.single_attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

你可能感兴趣的:(论文阅读笔记,transformer,深度学习,计算机视觉)