GAN初探:模拟高斯分布

GAN模拟高斯分布

实验目的

在《Generative Adversarial Nets》这篇论文中,用到了GAN模型去模拟原始数据的分布,即使得 p g = p d a t a p_g = p_{data} pg=pdata,现在使用一维高斯分布输入模型,将 p g = p d a t a p_g = p_{data} pg=pdata的逼近过程,用可视化的图形表示,看看这个过程是如何实现的。

实验过程

数据集

在整个实验的过程中,所有的数据都是用numpy生成的均值为3,标准差为1的高斯分布。数据长度为1*500,数据集大小为1000。

模型设计

  • 生成器:使用与《Generative Adversarial Nets》论文中相同的模型(简单的多层感知机网络),输入为1*100的随机噪声,输出为1*500的模拟分布
  • 判别器:同样使用与《Generative Adversarial Nets》论文中相同的模型(简单的多层感知机网络),输入为1*500的数据分布,输出为概率值

训练方式

同样与论文中相同,生成器与判别器各训练一次,一共2000个epoch

训练结果及分析

在原论文中,使用的数据是MNIST数据集,为了探究分布的逼近过程,简单地将数据进行替换是行不通,如下图给出的均值标准差变化:

GAN初探:模拟高斯分布_第1张图片

可以看到,均值和标准一直在震荡,而且随着训练伦次的增加,震荡的幅度并没有减小。我尝试过修改网络参数,会不会是参数不合适,在试过多个Loss,学习率等参数的调整后,并没有实质性的变化。

输出生成器的模拟数据,如下图所示:

图中给出的是分布直方图与核函数估计(学生知道的是核函数估计是与概率密度分布差不多的东西,核函数曲线的积分面积为1,可以描述数据分布的情况),在图中可以看到虽然均值和标准差是一直在震荡,但是模型的密度分布一直满足高斯分布的"形状"。可以理解为生成器输出为原数据的放缩,虽然均值和标准差与原数据不同,但每次输出的数据都满足一定的高斯分布。

模型改进

为了解决均值和标准差震荡的问题,学生认为这不是参数调整不当的问题,应该从模型本身入手才有可能解决,通过查阅网络上的资料,以及参考前人的例子,我发现可以通过修改判别器解决震荡问题。修改后的模型如下

  • 生成器:无变化
  • 判别器:使用多层感知机实现,但输入为描述高斯分布的四个特征,分别为均值、标准差、峰度与偏度,输出为概率值

在修改后使用相同的数据集训练,输入数据的均值与标准差变化如下:

GAN初探:模拟高斯分布_第2张图片

在图中可以看到,均值和标准差的变化在多次训练后趋于稳定,生成器很好的学习到了均值和标准差信息。

输出数据的密度分布呢如下所示:

从图中可以看到,当判别器没有使用整体的数据输入时,密度分布图就开始跑偏,在开始的(300, 600)伦次的训练中,密度分布其实还有高斯分布的“形状”但是在多几轮训练后,输出的数据就过分的注重均值和方差,数据大部分集中在3附近,其余大多在1.5和4.5附近,其实就是有些过拟合了。

实验总结

通过两周两种模型的尝试,我发现对于生成器而言,判别器的引导作用相当重要。不同于普通深度学习的标签引导,判别器是个很灵活的"老师",能够通过设计不同的判别器,让同样的模型朝着不同的方向发展。回到最初的实验目的, p g p_g pg是如何逼近 p d a t a p_{data} pdata的?我认为在本次实验里,起主要作用的是判别器,在原论文中,使用的数据是图像,只要生成器输出的数据满足数据集“形状”大致相同,就能通过判别器的检验了,所以能够得到比较好的效果。但是在高斯分布中,具体到均值和标准差这样很细致的点,由于判别器不注重这方面的信息,所以模型就训练不到,生成的数据自然就跑偏,这个在改进的模型中就得到了证明。

后续可以做的点

  1. 阅读近几年的论文,看看最新的GAN如何设计,能够解决什么样的问题,从而可以找到一些启发
  2. 设计混合输入模型,把上面两种模型结合

第一点其实是一直都要做的,学生现在读的论文还比较少,复现的不多,我觉得应该多看看近期的论文,看看GAN针对具体的问题都做了哪些的改进。第二个点是我自己想到的,学生上网查了下还真有混合输入的模型(好像不是什么好事,说明别人都做过了T_T),就当作学习一下前人的工作。在多输入模型里,我觉得可以把传统的算法使用的参数给添加到判别器里,这样把机器学习黑盒式判别器和传统算法结合起来,说不定有不一样的事情会发生。

你可能感兴趣的:(人工智能,论文学习,网络,算法,机器学习,深度学习,python)