- 本文是 Make Your First GAN With PyTorch 的第 5 章,本书的介绍详见这篇文章。
在探索生成对抗网络(GAN)前,先设置一些场景进行基础的认识。
一般而言,使用神经网络是为了约简、凝练、总结信息,比如 MNIST 分类器 是一个很好的例子,网络有 784 个输入,输出 10 个值,输出比输入的数量少很多。
下面来完成一个思想实验,如果将这些网络从前到后反转,将完成约简的相反工作,也就是扩展少量数据到大量数据。在这种情况下,将能生成图像。
事实上,如果将上面训练好的网络反向,将代表数字的向量输入,可以产生 28✖️28 的图像:
但是上述过程创建的图像有如下特点(或者叫问题):
- 给定的输入,输出的图像总是相同的;
- 输出图像是对应标签所有训练图像的某种平均。(看上图似乎像 3 又不完全是 3)
考虑上面的问题,其实使用网络来生成图像的想法很好,但是理想的网络应能够:
- 创建的图像不能完全一致;
- 生成的图像应该 “更真实”,而不是训练图像的平均值。
这两个问题对生成真实和有用的图像很重要,简单的网络反向方法并不能解决这些挑战,需要一种新的网络架构。
2014 年,Ian Goodfellow 提出了一种不同类型的网络架构,该网络架构并没有更复杂,也没有使用更花哨的激活函数或更高级的优化技术;但是,这个网络的架构完全不同。
下面一步步进行解释。
下图显示了对图像进行 猫 和 非猫 分类的神经网络:
如果输入到网络的图像是一只猫,则输出值应为 1,代表 True;如果图像不是猫,则输出为 0,代表 False。
- 上图这个架构与之前 MNIST 例子区别不大,唯一的区别是分类输出值是单独一个值,而不是 10 个。
对任务作出一点小改变,改变分类器试图分辨 猫 和 非猫 的情况,制作一个分类器来分辨 真实猫 的图片和自己画的 假猫 的图片(下图所示):
- 看起来这个架构并没有什么值得注意的变化,输入仍然有两种图像,且神经网络分类器可以训练来分类这两类。
可以将这个分类器看作一个 侦探(detective)。在进行训练前,这个侦探并不擅长分辨 真猫 和 假猫;随着训练的开展,侦探将逐渐擅长从 真猫 中分辨 假猫。
下面,假设有一个能生成 假猫 图片的生成器:
其实,生成一点都不像猫的垃圾图像很容易,比如我们可以简单地画出随机的三角形。
但是,我们不满足于一个只能生成垃圾图像的生成器。假设有一个神经网络,可以通过训练来产生相对真实的图像,将这个神经网络称之为 生成器(generator)。同时,称分类器为 鉴别器(discriminator),这是类似网络中的公认名称。
下面考虑如何训练这个 生成器。所谓训练,是关于奖励哪种行为,惩罚哪种行为的工作,这也是损失函数完成的工作:
先不考虑损失函数,回过头看一下整体考虑,下面略微有点绕,需要认真阅读。
- 鉴别器 和 生成器 设置为彼此竞争,作为 对手(adversaries),每个都试图超过另一个。随着这个过程,鉴别器 和 生成器 都变得越来越好,这个架构称之为 对抗生成网络(Generative Adversarial Network) ,或简称为 GAN。
这是一个聪明的架构,并不仅仅因为它使用了竞争来驱动改进,而且因为 不需要 定义详细的规则来描述损失函数中的真实图像。
机器学习的历史已经表明,我们并 不善于 定义这些规则,取而代之的是让 GAN 自主学习真实图像是什么。
上面所描述的是真正令人兴奋的,世界领先的机器学习研究者之一 Yann LeCun 称 对抗式学习(adversarial training) 是 “过去二十年机器学习中最酷的想法” 。
在一个 GAN 中,生成器 和 鉴别器 都需要训练,而且,不要 先训练一个,之后然后再训练另一个,而是期望 生成器 和 鉴别器 能同时学习。
下面的三步训练循环是完成这个的一个方法:
- 第一步: 向 鉴别器 展示一个真实的图像,然后,告诉 鉴别器 ,这个样本的分类应该是 1;
- 第二步: 向 鉴别器 展示一个 生成器 的输出,然后,告诉 鉴别器 ,这个样本的分类应该是 0;
- 第三步: 向 鉴别器 展示一个 生成器 的输出,然后,告诉 生成器 ,这个样本的分类应该是 1。
上面三个步骤是绝大多数 GAN 训练方案的核心,可能很难理解,下面通过一些图片解释这些步骤的含义。
- 下面的介绍中,鉴别器 和 生成器 都加粗,需要认真注意,避免混淆。
步骤一 是最简单的,主要是向 鉴别器 展示一个来自真实数据集的图像,请求它对样本图像进行分类(输出 0 或者 1)。
由于这里的 预期输出 应该是 1,所以,我们可以使用 损失值来更新 鉴别器。
但这次是由 鉴别器 对一个来自 生成器 的图像进行分类(输出 0 或者 1)。
由于这时 鉴别器 的 预期输出 应该是 0,所以也可以使用 损失值 来更新 鉴别器。
- 这里操作必须很小心, 不要在这步更新 生成器,因为我们并不想在 生成器 生成的虚假图像被 鉴别器 查获假的图像时,仍给 生成器 奖励。
- 下面对 GAN 编程时,将展示如何防止通过计算图返回更新 生成器。
使用 生成器 来产生 虚假图像,用于展示给 鉴别器 来分类(输出 0 或者 1)。
这时 鉴别器 的期望输出应该是 1,也就是我们希望 生成器 能够生成足够迷惑 鉴别器 的图像。
与 步骤二 的区别在于,这里的 损失值 仅仅用来更新 生成器,而并不想在 鉴别器 分类错误时却仍然鼓励它,所以这一步并不更新鉴别器。
- 前面几个步骤看起来很复杂,但只要结合示意图认真理解,后面对 GAN 编程时, 将看到操作起来很简单。
上面刚讨论了 GAN 的训练框架,但实际上训练 GAN 可能很困难。
由于生成器和鉴别器在训练时 相互对抗,如果它们平衡的很好,那么它们将互相改进并不难;但如果鉴别器变好的太快,生成器可能永远跟不上;相反的,如果鉴别器学习的太慢,生成器将由于劣质的图像而得到奖励(训练效果同样不好)。
- 其实, GAN 是机器学习中一个新想法,目前对如何使得它更好工作的理解还处在初级阶段,这经常导致训练网络失败。
- 当然,先将 GAN 可能失效的理论讨论放到一边,首先开始构建 GAN 。当各种问题在训练中浮现时,我们再探索它。
- 分类(Classification) 是数据的约简,分类神经网络会把很多输入值减少 到一个数量少很多的输出值,每个输出值对应一类;
- 生成(Generation) 一般是数据的扩展,生成神经网络会将少数的输入 种子(seed) 值扩展到数量多很多的输出值(比如图像的像素值);
- 生成对抗网络(GAN) 有两个神经网络:一个 生成器(generator) 和一个 鉴别器(discriminator);两个网络相互竞争,称为 对手(adversaries):鉴别器 通过将训练集里面的数据分类为 真实(real),将生成器产生的数据分类为 虚假(fake) 而得到训练;生成器 则通过创建看起来足够真的数据,使得数据骗过鉴别器而得到训练;
- 可靠地设计和训练 GAN 成功很难,GAN 是新的技术,而且形容它如何 工作、为何会训练失败的理论目前还不成熟;
- 标准的 GAN 训练循环有 3 个步骤:① 使用真实的数据样本训练鉴别器;② 使用生成的数据样本训练鉴别器;③ 通过鉴别器将生成器生成的图像识别为真的图像,来训练生成器。