使用 Pytorch 生成对抗性网络(GANs) | 附完整代码链接

欢迎关注 “小白玩转Python”,发现更多 “有趣”

2014年,Ian Goodfellow 和他在蒙特利尔大学的同事们发表了一篇令人惊叹的论文,向世界介绍了 GANs),即生成对抗性网络。通过计算图和博弈论的创新组合,他们表明,如果给予足够的建模能力,两个互相攻击的模型将能够通过普通的反向传播进行协同训练。

模型扮演着两种截然不同的角色。给定一些真实的数据集 R,G 是生成器,试图生成看起来像真实数据的假数据,而 D 是鉴别器,从真实数据集或 G 中获取数据并标记差异。Goodfellow 比喻(这是一个很好的比喻):G 就像一个伪造者的团队,试图根据他们的输出匹配真实的数据,而 D 是一个侦探团队,试图分辨其中的差异。(只是在这种情况下,伪造者 G 永远看不到原始数据ーー只能看到 D 的判断,他们就像盲目的伪造者。)

使用 Pytorch 生成对抗性网络(GANs) | 附完整代码链接_第1张图片

在理想情况下,随着时间的推移,D 和 G 都会变得更好,直到 G 基本上成为真正文章的“大师伪造者”,D 感到茫然,“无法区分这两种分布”。

实际上,Goodfellow 展示的是 G 可以在原始数据集上执行一个非监督式学习/数据集的形式,找到以一种更低维度的方式来表示数据的方法。正如 Yann LeCun 的名言所说:unsupervised learning is the “cake” of true AI。

这个强大的技术看起来需要大量的代码才能开始,对吧?没有。使用 PyTorch,我们实际上可以用不到50行代码创建一个非常简单的 GAN。实际上只有5个部分需要考虑:

  • R:原始的、真实的数据集

  • I:进入发生器的随机噪声

  • G:试图复制/模拟原始数据集的生成器

  • D:试图分辨生成器G的输出和真实数据R的鉴别器

  • 实际的“训练”的循环,使得生成器G能够骗过鉴别器D,同时D能提防G。

1) R: 在我们的例子中,我们将从最简单的 R ー a bell curve 开始。这个函数取一个平均值和一个标准差,然后返回一个函数,这个函数用这些参数从高斯函数中提供正确的样本数据尺寸。在我们的示例代码中,我们将使用4.0的平均值和1.25的标准差。

2) I: 生成器的输入也是随机的,但是为了让我们的工作更难一点,我们使用一个统一的分布,而不是一个正常的分布。这意味着我们的模型 G 不能简单地 shift/scale 的方式复制 R,而必须以非线性的方式对数据进行处理。

3) G: 生成器是一个标准的前向图:两个隐藏层,三个线性映射。我们使用双曲正切激活函数。G 将从 I 中得到均匀分布的数据样本,并以某种方式模拟从 R ー中得到的正态分布的样本,而不会看到 R。

使用 Pytorch 生成对抗性网络(GANs) | 附完整代码链接_第2张图片

4) D: 鉴别器代码非常类似于 G 的生成器代码:一个有两个隐藏层和三个线性映射的前向图。这里的激活函数是 sigmoid。它将从 R 或 G 中获取样本,然后输出一个介于0到1之间的数值,被解释为“假的”或“真的”。换句话说,这大概是神经网络所基本能做到的。

使用 Pytorch 生成对抗性网络(GANs) | 附完整代码链接_第3张图片

5) 最后,训练循环在两种模式之间交替进行:首先在真实数据和虚假数据中训练 D;然后训练 G 愚弄 D。

使用 Pytorch 生成对抗性网络(GANs) | 附完整代码链接_第4张图片

即使你以前没有使用过 PyTorch,你也可能从代码结构中看出发生了什么。在第一部分(绿色)中,我们将这两种类型的数据推送到 D 中,并对 D 的猜测与实际的标签应用一个可区分的标准。这个推进就是“forward”步骤;然后我们显式地调用“backward()”以便计算梯度,然后在 d_optimizer.step() 调用中使用梯度来更新 D 的参数。使用 G,但不在这里训练。

然后在最后(红色)部分,我们对 G 做类似的操作(注意,我们也通过 D 运行 G 的输出) ,但是我们在这一步没有优化或改变 D。我们不希望 D 记错标签。因此,我们只调用 g_optimizer.step()。

在 D 和 G 之间进行了几千次这种“交手”之后,我们得到了什么?鉴别器 D 很快就好了(而 G 缓慢上升) ,但一旦它达到一定的能力水平,才会真正开始改善。

超过5000轮训练后,D 被训练了20次,然后在每轮训练 G 20次,G 的输出的平均值超过4.0,然后回落到一个相对稳定,正确的范围(左)。同样地,标准差最初下降到错误的方向,但随后上升到理想的1.25范围(右) ,与 R 相匹配。

使用 Pytorch 生成对抗性网络(GANs) | 附完整代码链接_第5张图片

好了,基本的统计数据最终与 R 匹配。?分布的形状看起来正确吗?毕竟,你当然可以得到一个平均值为4.0,平均标准差为1.25的均匀分布,但这不会真正匹配 R。让我们看看 G 发出的最终分布:

使用 Pytorch 生成对抗性网络(GANs) | 附完整代码链接_第6张图片

生成器G几乎完美地恢复原始数据的分布R - 并且鉴别器D在角落里畏缩,对自己喃喃自语,无法从中“说出”事实。这正是我们想要的结果,这里总共才不到50行代码。

提示:GAN很挑剔,而且很脆弱。当它们进入奇怪的状态时,往往不会在没有一点哄骗的情况下出来。运行示例代码十次(每次超过5,000轮)显示以下 10 个分布:

使用 Pytorch 生成对抗性网络(GANs) | 附完整代码链接_第7张图片

如上图所示,10 次运行中,有 8 次可以得到比较不错的分布:类似于高斯分布。但其中两次没有生成这样的分布。在一种情况下(第5次运行),有一个凹面分布,平均值约为6.0,在最后一次运行中,在-11处有一个狭窄的峰值!当您开始在几乎任何环境中应用 GAN 时,您会看到这种现象:GAN 并不像监督学习工作流程那样稳定。但是当它们工作时,它们看起来非常不可思议。

附完整代码链接:

https://github.com/devnag/pytorch-generative-adversarial-networks

·  END  ·

HAPPY LIFE

你可能感兴趣的:(python,机器学习,人工智能,深度学习,神经网络)