GAN生成对抗网络:由两个子网络组成,generator和discriminator,在训练过程中,这两个子网络进行着最小最大值机制,generator用随机向量输出一个目标数据分布的样例,discriminator从目标样例中区分出生成器生成的样例。generator通过后向传播混淆discriminator,依此generator生成与目标样例相似的样例。
这篇论文中,将generator换成一个分割网络(可以是任意形式的分割网络,如:FCN,DeepLab,DilatedNet……,输入是H*W*3,依次是长宽,通道数,输出概率图为H*W*C,其中C是语义种类数),这个网络对输入的图片分割输出一个概率图,使得输出的概率图尽可能的接近ground truth。其中discriminator采用了全卷积网络(输入为generator或ground truth得到的概率图,输出位空间概率图H*W*1,其中其中像素点p代表这个来自gournd truth(p=1)还是generator(p=0)。
代码
在训练中,用半监督机制,一部分是注解数据,一部分是无注解数据。
当用有注解数据时,分割网络由基于ground truth的标准交叉熵损失和基于鉴别器的对抗损失共同监督。注意,训练discriminator只用标记数据。
当用无注解数据时,用半监督方法训练分割网络,在从分割网络中获取未标记图像的初始分割预测后,通过判别网络对分割预测进行传递,得到一个置信图。我们反过来将这个置信图作为监督信号,使用一个自学机制来训练带masked交叉熵损失的分割网络。置信图表示了预测分割的质量。
输入图像 xn x n 大小为H*W*3, 分割网络表示为 s(⋅) s ( · ) ,预测概率图为 s(xn) s ( x n ) 大小为H*W*C。全卷积discriminator表示为 D(⋅) D ( · ) ,其输入有两种形式:分割预测 s(xn) s ( x n ) 和one-hot编码的gournd truth Yn Y n .
最小化空间交叉熵损失 LD L D ,其表示为:
这里使用的损失是多任务损失:
由于没有ground truth,因此这里不能使用 Lce L c e ,这里提出了用自学机制在无注解数据中利用被训练的discriminator,大意是被训练的discriminator可以生成一个置信图,即 D(S(Xn))(h,w) D ( S ( X n ) ) ( h , w ) ,这个公式用来推断预测结构足够接近gournd truth的区域。这里用一个阈值来二值化置信图, Y^=argmax(s(xn)) Y ^ = a r g m a x ( s ( x n ) ) ,使用二值化置信图,半监督损失可以定义为: