通过 “猫片” 认识生成对抗网络思想(Make Your First GAN With PyTorch 第五章)

  • 本文是 Make Your First GAN With PyTorch 的第 5 章,本书的介绍详见这篇文章。

本文目录

  • 1.如何生成图像?
  • 2.对抗性训练(Adversarial Training)
  • 3.训练 GAN
    • 3.1 ☆☆☆步骤一☆☆☆
    • 3.2 ☆☆☆步骤二☆☆☆
    • 3.3 ☆☆☆步骤三☆☆☆
  • 4. GAN 并不容易被训练
  • 5. 学习要点


在探索生成对抗网络(GAN)前,先设置一些场景进行基础的认识。

1.如何生成图像?

一般而言,使用神经网络是为了约简、凝练、总结信息,比如 MNIST 分类器 是一个很好的例子,网络有 784 个输入,输出 10 个值,输出比输入的数量少很多。

通过 “猫片” 认识生成对抗网络思想(Make Your First GAN With PyTorch 第五章)_第1张图片
下面来完成一个思想实验,如果将这些网络从前到后反转,将完成约简的相反工作,也就是扩展少量数据到大量数据。在这种情况下,将能生成图像。

事实上,如果将上面训练好的网络反向,将代表数字的向量输入,可以产生 28✖️28 的图像:

通过 “猫片” 认识生成对抗网络思想(Make Your First GAN With PyTorch 第五章)_第2张图片

但是上述过程创建的图像有如下特点(或者叫问题):

  • 给定的输入,输出的图像总是相同的;
  • 输出图像是对应标签所有训练图像的某种平均。(看上图似乎像 3 又不完全是 3

考虑上面的问题,其实使用网络来生成图像的想法很好,但是理想的网络应能够:

  • 创建的图像不能完全一致;
  • 生成的图像应该 “更真实”,而不是训练图像的平均值。

这两个问题对生成真实和有用的图像很重要,简单的网络反向方法并不能解决这些挑战,需要一种新的网络架构。

2.对抗性训练(Adversarial Training)

2014 年,Ian Goodfellow 提出了一种不同类型的网络架构,该网络架构并没有更复杂,也没有使用更花哨的激活函数或更高级的优化技术;但是,这个网络的架构完全不同。

下面一步步进行解释。

下图显示了对图像进行 非猫 分类的神经网络:

通过 “猫片” 认识生成对抗网络思想(Make Your First GAN With PyTorch 第五章)_第3张图片
如果输入到网络的图像是一只猫,则输出值应为 1,代表 True;如果图像不是猫,则输出为 0,代表 False

  • 上图这个架构与之前 MNIST 例子区别不大,唯一的区别是分类输出值是单独一个值,而不是 10 个。

对任务作出一点小改变,改变分类器试图分辨 非猫 的情况,制作一个分类器来分辨 真实猫 的图片和自己画的 假猫 的图片(下图所示):

通过 “猫片” 认识生成对抗网络思想(Make Your First GAN With PyTorch 第五章)_第4张图片

  • 看起来这个架构并没有什么值得注意的变化,输入仍然有两种图像,且神经网络分类器可以训练来分类这两类。

可以将这个分类器看作一个 侦探(detective)。在进行训练前,这个侦探并不擅长分辨 真猫假猫;随着训练的开展,侦探将逐渐擅长从 真猫 中分辨 假猫

下面,假设有一个能生成 假猫 图片的生成器:

通过 “猫片” 认识生成对抗网络思想(Make Your First GAN With PyTorch 第五章)_第5张图片
这省了准备各种 假猫 图片,而是使用代码生成它们。

其实,生成一点都不像猫的垃圾图像很容易,比如我们可以简单地画出随机的三角形。

但是,我们不满足于一个只能生成垃圾图像的生成器。假设有一个神经网络,可以通过训练来产生相对真实的图像,将这个神经网络称之为 生成器(generator)。同时,称分类器为 鉴别器(discriminator),这是类似网络中的公认名称。

通过 “猫片” 认识生成对抗网络思想(Make Your First GAN With PyTorch 第五章)_第6张图片
下面考虑如何训练这个 生成器。所谓训练,是关于奖励哪种行为,惩罚哪种行为的工作,这也是损失函数完成的工作:

  • 如果生成器生成了能 通过 鉴别器的图像,我们就 奖励 它;
  • 如果生成器生成的图像 未能 通过鉴别器,我们就 惩罚 它。

先不考虑损失函数,回过头看一下整体考虑,下面略微有点绕,需要认真阅读。

  • 鉴别器 的工作是分辨真实图像和生成的图像,如果 生成器 并不太好,这项工作将很简单。
  • 但是如果训练 生成器(先不管怎么训练)效果很好,可以使得图像看起来越来越真实。
  • 另一方面,如果 鉴别器 随着训练变得越来越好,为了能够更有效地 “骗过” 鉴别器生成器 必须也变得越来越好。
  • 最后的结果,就是 生成器 可以变得擅长创造图像,使得生成的图像不能被 鉴别器 所区分。
  • 鉴别器生成器 设置为彼此竞争,作为 对手(adversaries),每个都试图超过另一个。随着这个过程,鉴别器生成器 都变得越来越好,这个架构称之为 对抗生成网络(Generative Adversarial Network) ,或简称为 GAN

这是一个聪明的架构,并不仅仅因为它使用了竞争来驱动改进,而且因为 不需要 定义详细的规则来描述损失函数中的真实图像。

机器学习的历史已经表明,我们并 不善于 定义这些规则,取而代之的是让 GAN 自主学习真实图像是什么。

上面所描述的是真正令人兴奋的,世界领先的机器学习研究者之一 Yann LeCun对抗式学习(adversarial training)“过去二十年机器学习中最酷的想法”

3.训练 GAN

在一个 GAN 中,生成器鉴别器 都需要训练,而且,不要 先训练一个,之后然后再训练另一个,而是期望 生成器鉴别器 能同时学习。

下面的三步训练循环是完成这个的一个方法:

  • 第一步:鉴别器 展示一个真实的图像,然后,告诉 鉴别器 ,这个样本的分类应该是 1
  • 第二步:鉴别器 展示一个 生成器 的输出,然后,告诉 鉴别器 ,这个样本的分类应该是 0
  • 第三步:鉴别器 展示一个 生成器 的输出,然后,告诉 生成器 ,这个样本的分类应该是 1

上面三个步骤是绝大多数 GAN 训练方案的核心,可能很难理解,下面通过一些图片解释这些步骤的含义。

  • 下面的介绍中,鉴别器生成器 都加粗,需要认真注意,避免混淆。

3.1 ☆☆☆步骤一☆☆☆

通过 “猫片” 认识生成对抗网络思想(Make Your First GAN With PyTorch 第五章)_第7张图片
步骤一 是最简单的,主要是向 鉴别器 展示一个来自真实数据集的图像,请求它对样本图像进行分类(输出 0 或者 1)。

由于这里的 预期输出 应该是 1,所以,我们可以使用 损失值来更新 鉴别器

3.2 ☆☆☆步骤二☆☆☆

通过 “猫片” 认识生成对抗网络思想(Make Your First GAN With PyTorch 第五章)_第8张图片
步骤二 同样是对 鉴别器 进行训练。

但这次是由 鉴别器 对一个来自 生成器 的图像进行分类(输出 0 或者 1)。

由于这时 鉴别器预期输出 应该是 0,所以也可以使用 损失值 来更新 鉴别器

  • 这里操作必须很小心, 不要在这步更新 生成器,因为我们并不想在 生成器 生成的虚假图像被 鉴别器 查获假的图像时,仍给 生成器 奖励。
  • 下面对 GAN 编程时,将展示如何防止通过计算图返回更新 生成器

3.3 ☆☆☆步骤三☆☆☆

通过 “猫片” 认识生成对抗网络思想(Make Your First GAN With PyTorch 第五章)_第9张图片
步骤三 是对 生成器 进行训练。

使用 生成器 来产生 虚假图像,用于展示给 鉴别器 来分类(输出 0 或者 1)。

这时 鉴别器 的期望输出应该是 1,也就是我们希望 生成器 能够生成足够迷惑 鉴别器 的图像。

步骤二 的区别在于,这里的 损失值 仅仅用来更新 生成器,而并不想在 鉴别器 分类错误时却仍然鼓励它,所以这一步并不更新鉴别器

  • 前面几个步骤看起来很复杂,但只要结合示意图认真理解,后面对 GAN 编程时, 将看到操作起来很简单。

4. GAN 并不容易被训练

上面刚讨论了 GAN 的训练框架,但实际上训练 GAN 可能很困难。

由于生成器和鉴别器在训练时 相互对抗,如果它们平衡的很好,那么它们将互相改进并不难;但如果鉴别器变好的太快,生成器可能永远跟不上;相反的,如果鉴别器学习的太慢,生成器将由于劣质的图像而得到奖励(训练效果同样不好)。

  • 其实, GAN 是机器学习中一个新想法,目前对如何使得它更好工作的理解还处在初级阶段,这经常导致训练网络失败。
  • 当然,先将 GAN 可能失效的理论讨论放到一边,首先开始构建 GAN 。当各种问题在训练中浮现时,我们再探索它。

5. 学习要点

  • 分类(Classification) 是数据的约简,分类神经网络会把很多输入值减少 到一个数量少很多的输出值,每个输出值对应一类;
  • 生成(Generation) 一般是数据的扩展,生成神经网络会将少数的输入 种子(seed) 值扩展到数量多很多的输出值(比如图像的像素值);
  • 生成对抗网络(GAN) 有两个神经网络:一个 生成器(generator) 和一个 鉴别器(discriminator);两个网络相互竞争,称为 对手(adversaries)鉴别器 通过将训练集里面的数据分类为 真实(real),将生成器产生的数据分类为 虚假(fake) 而得到训练;生成器 则通过创建看起来足够真的数据,使得数据骗过鉴别器而得到训练;
  • 可靠地设计和训练 GAN 成功很难,GAN 是新的技术,而且形容它如何 工作、为何会训练失败的理论目前还不成熟;
  • 标准的 GAN 训练循环有 3 个步骤: 使用真实的数据样本训练鉴别器; 使用生成的数据样本训练鉴别器; 通过鉴别器将生成器生成的图像识别为真的图像,来训练生成器。

你可能感兴趣的:(Pytorch,Make,First,GAN,With,PyTorch,生成对抗网络,pytorch,深度学习)