【GAN-语义分割】Adversarial learning for semi-supervised semantic segmentation

论文:https://arxiv.org/abs/1802.07934
代码:https://github.com/hfslyc/AdvSemiSeg
PASCAL-VOC2012数据集(vocdevkit、Vocbenchmark_release)介绍
适合初学者的PASCAL VOC2012数据集的下载及简单讲解

摘要

我们提出了一种使用对抗网络的半监督语义分割方法。虽然大多数现有的鉴别器 D 经过训练,可以在图像级别将输入图像分类为真实图像还是伪图像,但我们以全卷积的方式设计了一个鉴别器,以便在考虑空间分辨率的情况下,将预测的概率图与 Groud Truth 区分开。我们表明,通过将对抗性损失与分割网络上的标准CE Loss耦合,可以将所提出的鉴别器 D 用于提高语义分割的性能。另外,全卷积鉴别器 D 通过在未标记图像的预测结果中发现值得信赖的区域来实现半监督学习,从而提供附加的监督信号。与利用弱标记图像的现有方法相比,我们的方法利用未标记图像而不加任何注释来增强分割模型。在PASCAL VOC 2012数据集和Cityscapes数据集上的实验结果证明了我们算法的有效性。

介绍

我们将分割网络视为生成器。

与经过训练以生成给定噪声矢量的图像的典型生成器不同,我们的细分网络在给定输入图像的情况下输出语义标签的概率图。在此设置下,我们希望将分割网络的输出与真实标签在空间上映射的位置接近。

为此,我们采用对抗性学习方案,并提出了一种全卷积鉴别器,该模型可学习将真值标签图与分割预测的概率图区分开来。结合空间 CE Loss,我们的方法使用了对抗损失,该损失促使分割网络在高阶结构中生成接近于ground truth标签图的预测概率图。这个想法类似于使用概率图形模型,例如条件随机场(CRF)(Zheng等人,2015; Chen等人,2017; Lin等人,2016),但是在使用过程中没有额外的后处理模块测试阶段。此外,在推理过程中不需要鉴别器,因此我们提出的框架不会增加测试的计算能力。通过采用对抗学习,我们进一步利用了在半监督条件下拟议的卷积鉴别器。

这项工作的贡献如下:首先,我们开发了一个对抗框架,可以在推理过程中提高语义分割的准确性,而无需额外的计算负担。其次,我们通过利用未标记图像的鉴别网络响应来帮助训练分割网络,从而促进半监督学习

算法概述

图1显示了所提出算法的概述。我们的系统由两个网络组成:分割网络和鉴别器网络。前者可以是设计用于语义分割的任何网络,例如FCN,DeepLab,DilatedNet。给定一个尺寸为H×W×3的输入图像,分割网络将输出尺寸为H×W×C的类别概率图,其中C是目标数据集的语义类别数。

我们的鉴别器网络是一个基于FCN的网络,它将分类概率图作为输入,从分割网络或地面真值标签地图中输入,然后输出大小为Hx Wx 1的空间概率图。鉴别器的每个像素p输出映射表示该像素是从地面实况标签(p = 1)还是从分段网络(p = 0)采样。与采用固定尺寸输入图像(大多数情况下为64×64)并输出一个概率值的典型GAN鉴别器相比,我们将鉴别器转换为可完成任意大小输入的全卷积网络。重要的是,我们发现这种转变对建立对抗性学习计划至关重要。

在训练过程中,我们使用半监督设置下的标签图像和未标签图像。当使用标记的数据时,分割网络由标准的交叉熵损失(具有地面真实性标签图)和对抗性损失(由鉴别器网络)监督。请注意,我们仅使用标签数据来训练鉴别器网络。

对于未标记的数据,我们使用提出的半监督方法训练分割网络。从分割网络获得未标记图像的初始分割预测之后,通过将分割预测传递到鉴别器网络来获得置信度图。反过来,我们使用“自学式”方案将此置信度图视为监督信号,以训练带有mask 交叉熵损失的分割网络。直觉是该置信度图指示了预测分割的局部质量,因此分割网络知道训练期间可以信任的区域。

图1

采用对抗网络的半监督训练

训练目标

给定大小为 的输入图像 ,我们将分割网络表示为 ,将预测概率图表示为大小 的 ,其中C是类别编号。 对于我们的全卷积鉴别器,我们将其表示为 ,它输出大小为 的两类置信度图,其中 是大小为 的类概率图 ,来自地面真值标签 或分割网络 。

鉴别网络训练。为了训练鉴别器网络,我们针对两个类别将空间交叉熵损失 最小化。损失可以正式写成

其中,如果样本是从分割网络中提取的,则 ;如果样本是来自真实标签的样本,则 。 注意,鉴别器网络将C信道概率图作为输入。 为了将大小为 的真实标签图 转换为C通道,我们只需通过构造概率图 来采用独热编码方案,其中 取值为1, ,否则为0。

Luc等人(2016年)提出鉴别器网络可以通过检测概率来容易地区分概率图是否来自真实值。但是,我们在训练阶段没有观察到这种现象。一个原因是我们使用全卷积方案来预测空间置信度,这增加了学习鉴别器的难度。此外,我们尝试Luc et al。(2016)提出的Scale方案,根据分割网络输出的分布,将真实实际概率信道稍微扩散到其他信道。然而,结果显示没有差异,因此我们在实验中不采用这种方案。

分割网络训练。我们建议通过最小化多任务丢失功能损失函数来训练分割网络:

其中 , 和 分别表示空间多类交叉熵损失,对抗性损失和半监督损失。、 是两个用于平衡多任务训练的常数。

我们首先考虑使用带注释的数据的场景。给定输入图像 ,地面真值 和预测结果 ,则交叉熵损失可通过以下公式获得:

给定完全卷积的鉴别器网络 ,我们通过对抗损失 采用对抗学习。

在这种对抗性损失的情况下,我们试图通过最大程度地将分割预测视为真实分布的概率来训练分割网络,以欺骗鉴别器。

使用未标记的数据进行训练。
使用未标记的数据进行训练。现在我们考虑半监督环境下的对抗训练。对于未标记的数据,显然我们不能应用 ,因为没有可用的地面真实标签。对抗损失 仍然适用,因为它只需要鉴别器网络。然而,我们发现,仅仅在没有 的情况下对未标记的数据应用对抗性损失时,性能会退化。这是合理的,因为鉴别器用作正则化并可能过度纠正预测以适应地面真实分布。

因此,我们建议使用“自学”策略来利用带有未标记数据的训练过的鉴别器。主要思想是训练过的鉴别器可以生成一个置信图,即 ,它推断出预测结果足够接近地面真实分布的区域。然后,我们用一个阈值对这个置信图进行二值化,以突出显示可信区域。我们使用这个二值化的置信度图将自学成的基本事实定义为掩码分割预测 。由此产生的半监督损失定义为:

是指示函数, 是控制自学过程灵敏度的阈值。在训练期间,将自学成的目标 和指示函数的值视为常量,因此 可以简单地视为带有Mask的空间交叉熵损失。在实践中,我们发现这个策略在 范围在0.1到0.3之间稳健运行。

网络架构

分割网络

我们采用DeepLab-v2(Chen等,2017)框架以及在ImageNet数据集(Deng等,2009)上预先训练的ResNet-101(He等,2016)模型作为我们的分割-基线网络。然而,由于内存问题,我们没有采用Chen 中提出的多尺度融合。根据最近关于语义分割的工作实践(Chen et al。,2017; Yu&Koltun,2016),我们删除了最后一个分类层,并将最后两个卷积层的跨度从2修改为1,使分辨率为输出要素图有效地是输入图像大小的1/8倍。为了扩大感受野,我们在conv4和conv5层分别应用了步长为2,为4的空洞卷积(Yu&Koltun,2016)。在最后一层之后,我们采用了Chen提出的空洞空间金字塔池化(ASPP)等(2017)作为最终分类器。最后,我们将上采样层与softmax输出一起应用,以匹配输入图像的大小。

鉴别器网络

对于鉴别器网络,我们遵循Radfordet等人使用的结构。它由5个卷积层组成,卷积核为4×4,通道数为{64,128,256,512,1},步幅为2。每个卷积层后都有一个Leaky-ReLU(Maas等人,2013),除最后一个参数外,其参数为0.2。要将网络转换为全卷积网络,需要在最后一层添加上采样层,以将输出重新缩放为输入图的大小。请注意,我们没有采用批量归一化层。我们发现批次归一化层(Ioffe&Szegedy,2015)高度不稳定,因为该系统只能以小批量进行训练。

总结:

本文的总体思路是采用半监督的思路进行学习,半监督是将有pixel-level和image-level标注的数据集进行训练,再将训练好的网络用于没有pixel-level标注的数据集进行学习,所以在本文通过对抗网络作为语义分割网络,对于有标记的数据,首先通过交叉熵loss Lce训练分割网络,将ResNet101作为backbone,并使用ASPP作为最后一层conv,得到预分割的feature map,在判别器部分(仅使用标记的数据训练判别器),使用交叉熵loss Ld与Ladv进行训练,通过自学习方案得到置信度图(confidence MAP),通过有ground-truth的监督下, 训练判别器,使其无法分别生成的图像是数据集真实的图像还是生成的虚假的图像。对于未标记的数据,由于没有ground-truth监督,没有交叉熵loss,但Ladv同样适用,我们通过训练好的判别器逐像素设置one-hot编码,通过Lsemi辅助训练进行半监督。

参考:
2018 BMVC之GAN+seg:Adversarial Learning for Semi-Supervised Semantic Segmentation

你可能感兴趣的:(【GAN-语义分割】Adversarial learning for semi-supervised semantic segmentation)