深度学习【40】Improved Techniques for Training GANs

该论文提出了一些关于训练GAN的技巧,在mnist上生成的样本人类无法分辨真假,在CIFAR-10上生成的样本人类分辨的错误率为21.3%。

优化

feature matching

为G网络加了一个损失函数:
这里写图片描述
函数f表示D网络最后输出层的前一层特征图。f(x)由真实数据抽取而来,f(G(z))为G网络生成的图片抽取而来。

Minibatch discrimination

GAN训练过程中经常会出现G网络生成的图片为了能够欺骗D网络,而生成仅仅能够让D网络认为是真实的图片。也就是G网络生成的图片都太相似了,没有多样性。这是因为D网络没有一个能够告诉G网络,应该生成不相似的图片。为此作者提出了一个minibatch discrimination来解决这个问题。

minibatch disrimination通过计算一个minibath中样本D网络中某一层特征图之间的差异信息,作为D网络中下一层的额外输出,达到每个样本之间的信息交互目的。具体的,假设样本 xi 在D网络中某一层的特征向量为 f(xi)RA ,然后将 f(xi) 乘以一个张量 TRABC 得到张量 MiRBC 。然后对每个样本之间的M的行向量计算L1距离,得到 cb(xi,xj)=exp(||MibMjb||L1)R ,然后将所有的 cb(xi,xj) 相加得到 o(xi)b ,最后将B个 o(xi)b 并起来得到一个大小为B的向量 o(xi) :
深度学习【40】Improved Techniques for Training GANs_第1张图片
上述过程,如图所示:
深度学习【40】Improved Techniques for Training GANs_第2张图片

接着,将 o(xi) f(xi) 合并成一个向量作为D网络下一层的输入。

Historical averaging

加入了一个惩罚项,找来找去不知道具体怎么实现的。就不多说了。

One-sided label smoothing

标签平滑,比较操作起来比较简单。训练D网络的时候,生成真实图片的label时将1改成0.9就可以了。

Virtual batch normalization

DCGAN使用了BN,取得了不错的效果。但是BN有个缺点,即BN会时G网络生成一个batch的图片中,每张图片都有相关联(如,一个batch中的图片都有比较多的绿色)。
为了解决这个问题可以使用Reference batch normalization。Reference batch normalization(包含运行网络两次: 第一次是对一个minibatch的参考样本, 这里的参考样本是在训练开始以前被采样并且是保持不变的; 另一个是对当前的minibatch的样本进行训练。 特征的平均值和标准差使用参考样本的batch进行计算。 然后,使用这些统计的信息对两个batch的特征进行标准化处理。 此方法的一个缺点是模型容易对参考batch的样本过拟合。 为了稍微缓解此问题, 作者提出了virtual batch normalization, 对一个样本标准化时使用的统计信息是通过此样本与参考batch的联合来进行计算的。

Semi-supervised learning

作者还加入了半监督学习机制。说起来也很简单,就是在D网络中加入一个图片类别预测(比如imageNet的1000个类别)。损失函数变为:
深度学习【40】Improved Techniques for Training GANs_第3张图片
其中K表示K个类别, pmodel(y=K+1|x) 表示 x 为假的概率,相对于之前的1-D(x)。

实验结果

在imageNet上与DCGAN的对比
深度学习【40】Improved Techniques for Training GANs_第4张图片
左边的是DCGAN,右边是论文的结果。明显会比DCGAN更好一些。

你可能感兴趣的:(深度学习)