代码实例:教你实现infoGAN

实例:构建infoGAN生成MNIST模拟数据
本例演示在MNISTt数据集上使用infoGan网络模型生成模拟数据,并且加入标签信息的loss函数同时实现了AC-GAN的网络。其中的D和G都是用卷积网络来实现的,相当于DCGAN上面的infoGAN例子。
案例描述
通过使用InfoGAN网络学习MNIST数据特征,生成以假乱真的MNIST模拟样本,并发现内部潜在特征信息。
具体实现可以分为如下几个步骤:
1. 引入头文件并加载MNIST数据
假设MNIST数据放在本地磁盘跟目录的data下。本例中将使用前面介绍的slim模块构建网络结构,所以需要引入slim。当然也可以不用slim,引入slim的目的就是为了编写代码比较方便,不用考虑输入维度即相关权重的定义。最主要是slim还对反卷积有封装,下文会用到。

代码12-1  Mnistinfogan

代码实例:教你实现infoGAN_第1张图片

2.网络结构介绍
建立2个噪声数据(一般噪声和隐含信息)与label结合放到生成器中,生成模拟样本,然后将模拟样本和真实样本分别输入到判别器中,生成判别结果,重构造的隐含信息,还有样本标签。
在优化时,让判别器对真实的判别结果为1,对模拟数据的判别结果为0来做loss。对生成器让判别结果为1来做loss。
3.定义生成器与判别器
由于是先从模拟噪声数据来恢复样本,所以在生成器中。要使用反卷积函数。这里通过2个全连接,再接入两个反卷积完成样本的模拟生成的。并且每一层都有BN归一化处理。

代码12-1  Mnistinfogan(续)


代码实例:教你实现infoGAN_第2张图片

代码实例:教你实现infoGAN_第3张图片

对于判别器的输入是真正的样本,同样的也是经过两次卷积,在接两次全连接,生成的数据可以分别接不同的输出层,来产生不同的结果:1维输出对应判别结果1 还是0;10维输出对应分类结果;2维输出对应隐含维度信息。


4.定义网络模型
令一般噪声的维度为38对应节点为z_rand,隐含信息维度为2对应节点z_con,二者都是符合标准高斯分布的随机数。将它们与one_hot转换后的标签连接一起放到生成器中。

代码12-1  Mnistinfogan(续)

代码实例:教你实现infoGAN_第4张图片

对应判别器的结果,定义了两个0和1的数组y_fake与y_real。并且将x与生成的模拟数据gen放到判别器中,得到对应的输出
5. 定义损失函数与优化器
对于判别器中判别结果的loss有两个:真实输入的结果与模拟输入的结果,将二者和在一起生成loss_d。对于生成器的loss为自己输出的模拟数据让它在判别器中为真,定义为loss_g。
剩下还要定义网络中共有的loss值:真实的标签与输入真实样本判别出的标签、真实的标签与输入模拟样本判别出的标签、隐含信息的重构误差。定义好后创建两个优化器,将它们放到对应的优化器中。
这里用了个技巧将判别器的学习率设小,将生成器的学习率设大些。这么做是为了让生成器有更快的进化速度来模拟真实数据。优化同样是用AdamOptimizer方法。具体代码如下:

代码12-1  Mnistinfogan(续)

代码实例:教你实现infoGAN_第5张图片

代码实例:教你实现infoGAN_第6张图片

所谓的AC-GAN就是在上面将loss_cr加入到loss_c中。如果没有loss_cr,令loss_c= loss_cf,对于网络生成模拟数据是不影响的。但是却会损失真实class与模拟数据见对应的信息。
6. 开始训练与测试
建立session,在循环里使用run来运行前面构建的两个优化器。

代码12-1  Mnistinfogan(续)

代码实例:教你实现infoGAN_第7张图片

代码实例:教你实现infoGAN_第8张图片

测试部分分别使用loss_d和loss_g的eval来完成。上面代码运行后得到如下输出:

代码实例:教你实现infoGAN_第9张图片

整个数据集运行3次后,模型的测试结果可以看到,判别的误差在0.57左右,基本可以认为对真假数据的无法分辨了。
7.可视化
这部分通过两种显示来可视化结果:生成原样本与对应的模拟数据图片、生成隐含信息对应的图片。生成原样本与对应的模拟数据图片会将对应的分类以及预测分类还有隐含信息一起打印出来。生成隐含信息对应的图片中将在整个[0,1]空间里抽样与样本的标签混合一起,生成模拟数据。

代码12-1  Mnistinfogan(续)

代码实例:教你实现infoGAN_第10张图片

代码实例:教你实现infoGAN_第11张图片

上面代码运行后,生成如下结果:

代码实例:教你实现infoGAN_第12张图片

代码实例:教你实现infoGAN_第13张图片

在上面的结果中,可以很容易观察到,除了可控的类别信息一致外,隐含信息中某些维度具有非常显著的语义信息,例如:第二个元素“2”的第一个维度数值很大,表现出来的就是倾斜很大,同样的第五个元素“4”会看上去粗一些,这与其第二个维度的数值很大也是有关的。所以显然网络模型已经学到了MNIST 数据集的重要信息(主成分)。将隐含信息对应0、1间的数值抽样配合类别标签的图像生成如下:

代码实例:教你实现infoGAN_第14张图片

更多章节请购买《深入学习之 TensorFlow 入门、原理与进阶实战》 全本


你可能感兴趣的:(代码实例:教你实现infoGAN)