实时语义分割网络STDC原理与代码解析(CVPR 2021)

paper:Rethinking BiSeNet For Real-time Semantic Segmentation

official implementation:GitHub - MichaelFan01/STDC-Seg: Source Code of our CVPR2021 paper "Rethinking BiSeNet For Real-time Semantic Segmentation"

third-party implementation:mmsegmentation/mmseg/models/decode_heads/stdc_head.py at main · open-mmlab/mmsegmentation · GitHub

存在的问题 

为了做到实时推理,很多实时语义分割模型选用轻量骨干网络,但是由于task-specific design的不足,这些从分类任务中借鉴来的轻量级骨干网络可能并不适合解决分割问题。

除了选用轻量backbone,限制输入图像的大小是另一种提高推理速度的常用方法,但这很容易忽略边缘附近的细节和小物体。为了解决这个问题,BiSeNet采用了多路径结构将低层细节信息和高层语义信息相结合,但添加一个额外的路径来获取低层特征非常耗时,而且辅助路径总是缺乏低层信息的指导。

本文的创新点

本文设计了一种新的结构,叫做Short-Term Dense Concatenate module(STDC module),通过少量参数就可获得不同大小的感受野以及多尺度信息。将STDC模块无缝集成到U-net架构中就得到了STDC network,大大提高了语义分割任务中的网络性能。

在decoding阶段,本文抛弃了额外添加一条路径的方法,而是采用Detail Guidance来引导低层空间细节信息的学习。首先利用Detail Aggregation module来生成细节的ground truth,然后利用bce loss和dice loss来优化细节信息的学习,这可以看作是一种side-information的学习,并且在推理时不需要这个side-information。

方法介绍

Design of Encoding Network

Short-Term Dense Concatenate Module

STDC模块的结构如图(3)的(b)(c)所示

实时语义分割网络STDC原理与代码解析(CVPR 2021)_第1张图片

每个module被分成了多个blocks,\(ConvX_{i}\) 表示第 \(i\) 个block的计算,因此第 \(i\) 个block的输出计算如下

其中 \(x_{i-1}\) 和 \(x_{i}\) 分别是第 \(i\) 个block的输入和输出,\(ConvX\) 包含一个卷积层一个BN层和一个ReLU激活层,\(k_{i}\) 是卷积核大小。

在STDC模块中,第一个block的卷积核大小为1,其余的都为3。假设STDC模块的输出通道数为 \(N\),除了最后一个卷积层的卷积核数量和前一个卷积层一样,第 \(i\) 个block的卷积核数量为 \(N/2^{i}\)。在分类任务中,通常高层的通道数更多。但在分割任务中,我们更关注可变的感受野大小和多尺度信息,低层需要足够的通道用小的感受野来编码更细粒度的信息,而具有更大感受野的高层更关注高级语义信息,设置和低层一样的通道数可能会导致信息冗余。下采样只在Block2中进行。为了丰富特征信息,通过skip-path将 \(x_{1}\) 到 \(x_{n}\) 的特征拼接起来作为STDC模块的输出。

Network Architecture

网络的完整结构如图3(a)所示,一共包含6个stage,其中stage1-5中分别进行一次stride=2的下采样,stage6通过一个ConvX一个全局平均池化和两个全连接层得到最终prediction logits。

stage1&2通常作为low-level层提取外观特征,为了追求效率,每个stage中只有一个卷积block。stage3,4,5中STDC module的个数经过仔细调整确定的,其中每个stage的第一个STDC module进行下采样。STDC network的详细结构如表2所示

实时语义分割网络STDC原理与代码解析(CVPR 2021)_第2张图片

Design of Decoder

Segmentation Architecture

本文用预训练的STDC network作为encoder的backbone,并且用BiSeNet中的context path来编码上下文信息。

实时语义分割网络STDC原理与代码解析(CVPR 2021)_第3张图片

如图4(a)所示,作者使用stage3, 4, 5来生成降采样率分别为1/8, 1/16, 1/32的特征图。然后使用全局平均池化来提供具有较大感受野的全局上下文信息。接着使用U-shape结构进行上采样,并与encoding阶段对应部分(stage4, 5)进行融合。这里借用BiSeNet中的Attention Refine module来进一步提炼stage4, 5的特征。最终的预测,也借用了BiSeNet中的Feature Fusion module来融合编码阶段stage3的降采样率为1/8的特征和解码阶段对应的输出。 

最终的Seg Head包括一个3x3 Conv-BN-ReLU和一个1x1 conv,得到最终N维的输出,这里N为类别数。损失函数为交叉熵,并使用了在线困难样本挖掘OHEM。

可以看出,STDC network借用了BiSeNet的整体结构,并且直接使用了BiSeNet中ARM module和FFM module。BiSeNet的结构如下图所示,对比图4可以看出,STDC network就是将BiSeNet中的spatial path和context path合二为一,浅层输出作为spatial path,深层GAP的输出作为context path,并且重新设计了网络的结构。关于BiSeNet的介绍可见BiSeNet v1原理与代码解读_attention refinement module_00000cj的博客-CSDN博客

实时语义分割网络STDC原理与代码解析(CVPR 2021)_第4张图片

Detail Guidance of Low-level Features

实时语义分割网络STDC原理与代码解析(CVPR 2021)_第5张图片

如图5所示,其中(c)是STDC中stage3的热力图,和BiSeNet中的spatial path相比,可以看出少了很多细节,因此作者提出了Detail Guidance module来指导低层学习空间信息。具体是把细节的预测建模成一个两类的分割任务,如图4(c)所示,首先利用拉普拉斯算子从分割任务的原始ground truth中生成detail map的ground truth。如图4(a)所示,在stage3处插入一个Detail Head来生成detail feature map,然后用detail gt来引导空间细节的学习。如图5(d)所示,在增加了deail guidance module后,细节信息丰富了许多。

Detail Ground-truth Generation

细节gt的具体生成过程如图4(c)所示,对于原始的分割GT,用不同步长的拉普拉斯算子得到多尺度的细节信息,拉普拉斯kernel如图4(e)所示,然后上采样到原始大小,然后用一个可训练的1x1卷积来融合不同尺度的细节信息,最后采用阈值0.1得到最终的binary detail ground-truth。

Detail Loss

因为细节的像素点比非细节要多得多,所以这是一个类别不平衡问题。由于加权交叉熵的结果不是那么精确,因此作者采用了交叉熵和Dice loss结合的方式来优化细节的学习。因为dice loss对前景/背景像素个数不敏感,所以可以缓和类别不平衡的问题。detail loss如下所示

Detail Head具体包括一个3x3 Conv-BN-ReLU和一个1x1 conv,在推理阶段,detail head可以直接舍弃。

实验结果

STDC network与其它轻量模型在ImageNet上的结果如表5所示,可以看出STDC取得了最好的accuracy-speed balance。

实时语义分割网络STDC原理与代码解析(CVPR 2021)_第6张图片

在Cityscapes上结果如表6所示

实时语义分割网络STDC原理与代码解析(CVPR 2021)_第7张图片 

与其他real-time分割模型相比,相同速度下,STDC获得了最好的精度。

代码解析

这里以MMSeg中的实现为例,骨干网络stdc network的实现在mmseg/models/backbones/stdc.py中,具体实现过程比较简单,这里就不详述了。其中需要需要注意的是,当stride=2时,前面提到过在每个stage的第一个stdc module的block2进行downsample,但在mmseg的实现中,并不是在原来的conv中设置stride=2,而是在原来的conv前面多加了一个stride=2的conv。

ARM和FFM模块都是照搬的BiSeNet v1中的实现,没有改动,具体介绍可见BiSeNet v1原理与代码解读

设置输入batch_size=16,输入大小为480x480,网络最终输出为

outputs = [outs[0]] + list(arms_out) + [feat_fuse]
# (16,256,60,60) + [(16,128,30,30),(16,128,60,60)] + (16,256,60,60)

其中outs[0]是stage3的输出,后续需要接Detail Head。arms_out是两个ARM模块的输出,在mmseg的实现中,对这两个输出用FCN当做auxiliary head进行监督(论文中没有提到),推理阶段去除。feat_fuse就是spatial info和context info经过FFM融合后的输出,后接FCN,并用bce loss + dice loss进行优化。

Detail Head的ground truth的生成代码如下,其中提到虽然论文中说经过三个不同步长的拉普拉斯卷积后用一个可训练的1x1卷积对三者进行融合,但在官方实现和其它第三方实现中,并没有使用可训练的1x1卷积,因此这里也是用了一个不可训练参数提前设定不随训练更新的fusion_kernel进行融合。

class STDCHead(FCNHead):
    """This head is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation `_.

    Args:
        boundary_threshold (float): The threshold of calculating boundary.
            Default: 0.1.
    """

    def __init__(self, boundary_threshold=0.1, **kwargs):
        super().__init__(**kwargs)
        self.boundary_threshold = boundary_threshold
        # Using register buffer to make laplacian kernel on the same
        # device of `seg_label`.
        self.register_buffer(
            'laplacian_kernel',
            torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1],
                         dtype=torch.float32,
                         requires_grad=False).reshape((1, 1, 3, 3)))
        self.fusion_kernel = torch.nn.Parameter(
            torch.tensor([[6. / 10], [3. / 10], [1. / 10]],
                         dtype=torch.float32).reshape(1, 3, 1, 1),
            requires_grad=False)

    def loss_by_feat(self, seg_logits: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute Detail Aggregation Loss."""
        # Note: The paper claims `fusion_kernel` is a trainable 1x1 conv
        # parameters. However, it is a constant in original repo and other
        # codebase because it would not be added into computation graph
        # after threshold operation.
        seg_label = self._stack_batch_gt(batch_data_samples).to(
            self.laplacian_kernel)  # (16,1,480,480)
        boundary_targets = F.conv2d(
            seg_label, self.laplacian_kernel, padding=1)
        boundary_targets = boundary_targets.clamp(min=0)
        boundary_targets[boundary_targets > self.boundary_threshold] = 1
        boundary_targets[boundary_targets <= self.boundary_threshold] = 0

        boundary_targets_x2 = F.conv2d(
            seg_label, self.laplacian_kernel, stride=2, padding=1)
        boundary_targets_x2 = boundary_targets_x2.clamp(min=0)

        boundary_targets_x4 = F.conv2d(
            seg_label, self.laplacian_kernel, stride=4, padding=1)
        boundary_targets_x4 = boundary_targets_x4.clamp(min=0)

        boundary_targets_x4_up = F.interpolate(
            boundary_targets_x4, boundary_targets.shape[2:], mode='nearest')
        boundary_targets_x2_up = F.interpolate(
            boundary_targets_x2, boundary_targets.shape[2:], mode='nearest')

        boundary_targets_x2_up[
            boundary_targets_x2_up > self.boundary_threshold] = 1
        boundary_targets_x2_up[
            boundary_targets_x2_up <= self.boundary_threshold] = 0

        boundary_targets_x4_up[
            boundary_targets_x4_up > self.boundary_threshold] = 1
        boundary_targets_x4_up[
            boundary_targets_x4_up <= self.boundary_threshold] = 0

        boundary_targets_pyramids = torch.stack(
            (boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up),
            dim=1)  # (16,3,1,480,480)

        boundary_targets_pyramids = boundary_targets_pyramids.squeeze(2)  # (16,3,480,480)
        boudary_targets_pyramid = F.conv2d(boundary_targets_pyramids,
                                           self.fusion_kernel)

        boudary_targets_pyramid[
            boudary_targets_pyramid > self.boundary_threshold] = 1
        boudary_targets_pyramid[
            boudary_targets_pyramid <= self.boundary_threshold] = 0

        seg_labels = boudary_targets_pyramid.long()
        batch_sample_list = []
        for label in seg_labels:
            seg_data_sample = SegDataSample()
            seg_data_sample.gt_sem_seg = PixelData(data=label)
            batch_sample_list.append(seg_data_sample)

        loss = super().loss_by_feat(seg_logits, batch_sample_list)
        return loss

你可能感兴趣的:(Real-time,segmentation,实时语义分割,深度学习,计算机视觉,人工智能)