深度学习【46】pix2pixHD

英伟达的pix2pixHD,能够合成高清的2048*1024图片,简直振奋人心。pix2pixHD是我之前介绍过的pix2pix的改进版本,使用多由粗到精的G网络和多尺度D网络(每个D网络都是用了pix2pix中同样的patch技术)。

论文一开始用pix2pix尝试合成更高分辨率的图片,但发现效果不好,训练也很不稳定。所以就对pix2pix进行了一顿魔改。我们来看看他们是怎么改进的。

由粗到精的G网络

直接上图:
深度学习【46】pix2pixHD_第1张图片

从图上可以看出,整个G网络其实是由两个子网络构成:G1(棕色框里面的)和G2(两个黑色框里面的)。G1网络负责生成1024*512的图片,而G2(G1的信息会输入到G2中)则生成2048*1024的图片。G1网络由一个用来下采样的前端网络 G(F)1 ,一系列的ressidual block GR1 以及一个用来上采样的后端网络 GB1 构成。G2网络也差不多。

训练的时候,先训练G1网络,然后训练G2网络,最后联合一起训练G1和G2网络。

多尺度D网络

为了能够分辨高分辨率的真实和合成图片,D网络需要比较大的感受野。经过更深的网络或这更大的卷积核可以有大的感受野,但是这也可能会容易过拟合,并且所需的GPU内存也就更大。(对于高分辨率来说,会需要很多的显存。)

为此作者提出了多尺度D网络。他们使用了3的D网络,每个网络分的输入是不同尺度的图片。这样就达到了D网络有大的感受野的目的。所以原始的GAN变成了一个多任务的GAN:
这里写图片描述

多样性损失函数

与论文Improved Techniques for Training GANs 中feature matching技术一样。pix2pixHD也加入了feature matching技术,不过更为激进。他们将D网络中所有层(除了输出层)的特征图都拿过来做feature matching了。
这里写图片描述

其中,k表示第k个D网络,i表示D网络中的第i层。s表示带转换图片,x表示转换目标图片,G(s)表示G网络生成的目标图片。

加入feature matching损失函数后,pix2pixHD的损失函数为:
这里写图片描述

使用Instance Maps

这个主要是针对cityscapes这个任务的一个优化。其实就是在G网络中加入了 instance maps(或者说物体的边沿信息)。这是因为cityscapes数据集的label map中多个物体连在一起的时候是没有边界信息的,这样不利于合成图片。而加入了边沿信息后就能够区分同label的不同物体了。

深度学习【46】pix2pixHD_第2张图片

上图说明了label map中多个同样label物体连在一起的问题(图片(a)中蓝色的车,好几辆车连在一起,根本分不清他们的边界)。图片(b)的边界信息将这些车分开了。

学习个体级别的特征向量

我们都知道,就算同一种label的物体,其外形也更能是多种多样的。为了让G网络生成的图片也有这样的功能。作者设计了一个新的G网络:
深度学习【46】pix2pixHD_第3张图片

上图中的G网络在原始的输入labels map中多了一个features map,正是这个features map控制了同一label不同物体的多样性。

我们解释一下这个features map怎么来的,特别需要理解一下在测试的时候怎么获取这个features map。

为了生成这个features map,作者加入了一个标准的自编码网络(绿框中的模型)。这个自编码网络的输出是一个大小与输入图片一样的feature map。这个网络一开始会跟G网络和D网络一起训练,其输入是目标转换图片。在将该网络的输出作为G网络的输入之前,论文将其进行一个Instance-wise(根据输入的label map进行池化)的均值池化。然后将池化后的特征传播到这个个体的所有位置(上图中G网络的输入Features可作为一个例子)。

当然在测试时,我们肯定没有目标转换图片的,因为我们就是为了生成目标转换图片。为此,作者在训练好这个自编码网络后,将训练集中的目标转换图片利用该自编码网络抽取出最后一层的特征图,然后利用Instance-wise计算出要输入G网络的features map。接着利用k均值聚类对同一abel的所有样本的均值特征向量进行聚类得到K个聚类。这样我们在测试的时候,G网络根据label map的信息,对在每个label从之前聚类好的K个中心随机挑选一个特征,来生成G网络控制多样性的输入(features map)。

实验结果

与不同模型的比较
深度学习【46】pix2pixHD_第4张图片

多样性
深度学习【46】pix2pixHD_第5张图片

你可能感兴趣的:(深度学习)