【神经网络】GAN原理总结,CatGAN

定义及原理:    

       生成器 (G)generator:接收一个随机的噪声z(随机数),通过这个噪声生成图像。G的目标就是尽量生成真实的图片去欺骗判别网络D。

       判别器(D) discriminator:对接收的图片进行真假判别。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。D的目标就是尽量辨别出G生成的假图像和真实的图像。

       GAN的主要灵感来源于博弈论中零和博弈的思想,应用到深度学习神经网络上来说,就是通过G和D不断博弈,进而使G学习到数据的分布,如果用到图片生成上,则训练完成后,G可以从一段随机数中生成逼真的图像。

      训练过程中,G和D构成了一个动态的“博弈过程”,最终的平衡点即纳什均衡点:生成器生成的图像接近于真实图像分布,而判别器识别不出真假图像,对于给定图像的预测为真的概率基本接近 0.5(相当于随机猜测类别)

过程

  1. 第一代的Generator,然后他产生一些图片
  2. 训练产生第一代discriminator,能够区分人工产生的和真实的图片
  3. 训练第二代Generator,使其产生的图片骗过第一代discriminator
  4. 以此类推。。。

优点

  1. 只用到了反向传播
  2. 相比其他所有模型, GAN可以产生更加清晰,真实的样本
  3. GAN应用到一些场景上,比如图片风格迁移,超分辨率,图像补全,去噪,避免了损失函数设计的困难,不管三七二十一,只要有一个的基准,直接上判别器,剩下的就交给对抗训练了

缺点

  1. 训练GAN需要达到纳什均衡,有时候可以用梯度下降法做到,有时候做不到.我们还没有找到很好的达到纳什均衡的方法,所以训练GAN相比VAE或者PixelRNN是不稳定的,但我认为在实践中它还是比训练玻尔兹曼机稳定的多
  2. GAN不适合处理离散形式的数据,比如文本
  3. GAN存在训练不稳定、梯度消失、模式崩溃的问题(目前已解决)

应用

  1. 图片生成
  2. 替换判别器为一个分类器,做多分类任务,而生成器仍然做生成任务,辅助分类器训练
  3. 和强化学习结合,目前一个比较好的例子就是seq-GAN

CatGAN

无监督的分类会被转化为一个聚类问题,通常是以某种距离作为度量准则,从而将数据划分为多个类别,而本文则是采用数据的熵来作为衡量标准构建来CatGAN (ICLR-2016) 。具体来说,对于真实的数据,模型希望判别器不仅能具有较大的确信度将其划分为真实样本,同时还有较大的确信度将数据划分到某一个现有的类别中去;而对于生成数据却不是十分确定要将其划分到哪一个现有的类别,也就是这个不确信度比较大,从而生成器的目标即为产生出那些“将其划分到某一类别中去”的确信度较高的样本,尝试骗过判别器。接下来,为了衡量这个确信程度,作者用熵来表示,熵值越大,即为越不确定;而熵值越小,则表示越确定。然后,将该确信度目标与原始GAN的真伪鉴别的优化目标结合,即得到了CatGAN的最终优化目标。

对于半监督的情况,对有标签数据计算交叉熵损失,而对无标签数据计算上面的基于熵的损失,然后在原来的目标函数的基础上进行叠加即得,当用该半监督方法进行目标识别与分类时,其效果虽然相对较优,但相对当下state-of-the-art的方法并没有比较明显的提升。但其基于熵损失的无监督训练方法却表现较好,其实验效果如下图所示,可以看到,对于如下的典型环形数据,CatGAN可以较好地找到两者的分类面,实现无监督聚类的功能。

【神经网络】GAN原理总结,CatGAN_第1张图片

GAN of Salimans et al. (2016)

参考:Improved Techniques for Training GANs

GAN网络使用梯度下降的方法只会找到低的损失,不能找到真正的纳什均衡。本论文中,作者通过引入了一些方法,提高网络的收敛。

原始的GAN网络的目标函数需要最大化判别网络的输出。作者提出了新的目标函数,motivation就是让生成网络产生的图片,经过判别网络后的中间层的feature 和真实图片经过判别网络的feature尽可能相同。

相比原先的方式,生成网络G产生的数据更符合数据的真实分布。作者虽然不保证能够收敛到纳什均衡点,但是在传统GAN不能稳定收敛的情况下,新的目标函数仍然有效。

判别网络从输入到输出逐层卷积,pooling,图片信息逐渐损失,因此中间层能够比输出层得到更好的原始图片的分布信息,拿中间层的feature作为目标函数比输出层的结果,能够生成图片信息更多,生成的图片会效果会更好。

  • Semi-supervised learning

对于GAN网络,可以把生成网络的输出作为第K+1类,相应的判别网络变为K+1类的分类问题。用Pmodel(y=K+1|x)Pmodel(y=K+1|x)表示生成网络的图片为假

你可能感兴趣的:(人工智能)