GAN的学习 - 实际的训练过程

20200824 -

0. 引言

前面的文章中,了解了GAN的基础知识,同时介绍了实际代码(如何利用GAN来生成MNIST的数据),前面两篇文章中,进行了具体的基础知识铺垫,但是我在实际的训练过程中,也感受到了GAN训练不稳定。本篇文章从这个角度来入手,来讲讲训练过程。
前文回顾:
本篇文章仅仅记录了但是在自己进行这部分实验过程中的一些疑惑,实际上并没有解决问题。关于具体的训练过程,包括在论文中的算法实现内容,查看文章《GAN的编写 - tensorflow形式(tensorflow与GAN同学习,重点分析训练过程)》

1. GAN的训练过程

1.1 GAN如何训练

在前文中提到GAN是通过训练两个神经来实现生成模型,其中两个模型分别是生成器模型和判别器模型。在训练过程中,判别器将尽可能提高识别真数据和假数据的能力,而生成器将尽可能生成真的数据;判别器使用真实数据和假数据一起来进行训练来更新权值,而生成器只能使用完整模型(冻结的判别器,输入假的数据,但标记为真)通过判别器的误差来反馈到前面实现权值更新。
前面的这些话中,透露出来这样的一个信息,生成器的权值更新是完全有判别器来决定的。那么如果是这样的话,如果判别器判别能力强,这样输入假数据(带真标签)就会形成比较大的误差,这样生成器就能有较大的权值更新;而如果是判别器判别能力弱(也可能是已经能够生成了真实数据),那么此时生成器的权值更新也不会很大了。
本身GAN的训练过程,应该是一个minimax过程,其中判别器将最小化判别真假数据过程中的误差,而生成器会最大化判别器判断出错的概率。

1.2 GAN的损失函数

这里不具体展开GAN的损失函数的数学原理,从原始的函数,到现在很多优化的变种,有很多损失函数。这里仅仅考虑原始涉及的GAN的损失函数。在实际训练过程中,分为两个训练过程,一个是判别器使用真实数据和假数据同时来进行训练;另一个是输入假数据(但标签为真的)到组合生成的GAN中,利用这部分误差来更新生成器。
也就是说,在训练过程中将有两个损失数值产生,一个是判别器的损失,一个是GAN的损失。书写这部分内容的原因,也正是因为我在训练过程中产生了疑惑才编写的。
正常情况下,在训练普通的神经网络时,例如分类过程,那么肯定是希望损失函数越小越好,此时分类的精确率也会越来越高。但是在这里却不一样。
根据前文的描述,我们首先来描述一下两个损失数值应该的趋势;但是因为我也不知道应该往什么方向来训练,这一块我也比较疑惑。那么现在具体描述一下。

首先,判别器在一开始的时候,应该损失函数比较高,因为一开始的权值都没有设定好,并没有学习到真正的用于分类的权值;但是一旦这个模型学习到了之后,应该就会降低。

然后,生成器的损失函数部分由完整的GAN部分生成,这部分的损失函数,实际上是由于传递了真标签的假数据来决定。正常情况下,一开始的时候,可能判别器的性能应该是比较好的,这个时候的算是函数比较低,此时生成器训练过程肯定是很慢的。但是最终的趋势应该是,判别器的损失数值很小,也就是说生成器已经生成了非常好的样本,此时传递进去真标签的假数据,那么判别器已经分辨不出来了。

我个人理解应该是这样一个流程。
但是这里书写这篇文章的原因,也是因为我在实践过程中,发现了一些不同的地方,或者是说不理解的趋势。

我本身要做的事情,是利用GAN网络来生成一个SIN一维信号。在判别器和生成器中,都利用LSTM作为一层。判别器就是鉴定这个数据是不是假的,而生成器就是接受一个随机的数值,然后输入到生成器中,然后利用LSTM的运算,实现最后的信号曲线。
但是,我看到的损失函数趋势是,混合GAN的损失函数越来越高,而判别器的损失数值越来越低,最后变为了0。我不是很清楚我什么地方做错了,网络是我自己涉及的,但是GAN的整体流程是看的别人的。
这里问题出在了什么地方我也不是很清楚。这里先这样记录着,看看后续是不是还能有所改进。

这篇文章将会在后续过程中继续更新。

(20200826 更新)
我之前的时候,可能没弄清楚这到底是怎么回事。
我觉得,我这里要理解的问题是,在比较理想的情况下,GAN的两个模型在训练过程中,应该是一个怎么样的趋势。可能一开始是震荡的,但是最终是不是收敛到一个值,这个收敛的值是不是固定的呢?之前的时候我看到过他们说好像是0.5,但是我始终没有看到过。
我当时看到的结果就是一个判别器变成了0,而生成器一直增大。

(20200901 更新)
在文章[1]中提到,如果达到的训练的稳定点(可能需要非常长的时间和非常多的训练迭代次数)时,判别器会判定生成器的生成数据为0.5,这一点是不是可以作为参考的地方呢?!当然,这仅仅时单个实例的概率。

参考

[1]Generative Adversarial Nets in TensorFlow

你可能感兴趣的:(深度学习,python,神经网络,深度学习,GAN)