条件生成对抗网络(CGAN, Conditional Generative Adversarial Networks)作为一个GAN的改进,其一定程度上解决了GAN生成结果的不确定性。如果在Mnist数据集上训练原始GAN,GAN生成的图像是完全不确定的,具体生成的是数字1,还是2,还是几,根本不可控。为了让生成的数字可控,我们可以把数据集做一个切分,把数字0~9的数据集分别拆分开训练9个模型,不过这样太麻烦了,也不现实。因为数据集拆分不仅仅是分类麻烦,更主要在于,每一个类别的样本少,拿去训练GAN很有可能导致欠拟合。因此,
CGAN就应运而生了。我们先看一下CGAN的网络结构: 从网络结构图可以看到,
对于生成器Generator,其输入不仅仅是随机噪声的采样z,还有欲生成图像的标签信息。
比如对于mnist数据生成,就是一个one-hot向量,某一维度为1则表示生成某个数字的图片。
同样地,判别器的输入也包括样本的标签。这样就使得判别器和生成器可以学习到样本和标签之间的联系。
Loss如下:
Loss设计和原始GAN基本一致,只不过生成器,判别器的输入数据是一个条件分布。在具体编程实现时只需要对随机噪声采样z和输入条件y做一个级联即可