DL | GAN: Generative Adversarial Nets 生成对抗网络算法学习

 

题外话:

本文是生成对抗网络GAN的基础理论学习的笔记,主要是基于Goodfellow于2014年发表的论文Generative Adversarial Nets。话说GAN已经火了这么多年,今天才提笔写笔记真是不好意思,然后刚才突然惊奇地发现,今天是GAN论文release在arXiv上正好整整五年。2014年6月10日,Ian Goodfellow向学术界发表了他同导师头脑风暴之下设计的对抗网络结构,他当时估计也没想到今后会有这么多人follow自己的工作吧(不过大神的世界也说不定早就预料到了)。GAN的特殊结构、训练策略、应用场景造就其在深度学习领域中成为“独特”的一个方向。题外话ending~

 

摘要

GAN生成对抗网络,是通过对抗的训练过程来得到目标生成模型的预估参数的一种特别的网络框架。所谓的对抗训练,即同时训练两个模型:一个生成模型G和一个判别模型D,生成模型G用于模拟数据的分布,判别模型D用于判断输入的样本是由真实的训练数据采样得来,还是从生成模型G模拟得到的数据。生成模型G的训练目标是最大化判别模型D判断错误,即让D无法判断输入的数据是真是假。这个模型框架像是两个游戏竞手互相博弈,因此得名为生成对抗模型。在任意的函数空间中,存在生成模型G可以尽可能模拟出训练数据的数据分布,使得判别模型D的输出概率为0.5(无法判断真假)。其中,G和D都由神经网络定义并使用反向传播进行训练。

 

背景介绍

2014年深度学习已展露头脚,尤其在分类任务(判别模型)上取得显著的性能提升。然而,当时已有的深度生成模型的影响力就相对较小了。一方面,由于生成模型中通常包含的棘手又复杂的概率计算问题很难处理,例如最大似然估计和相关的训练策略都有涉及,另一方面,深度学习中的分段线性单元也很难应用到生成模型中。因此,Goodfellow提出了这个全新的生成模型训练策略。

这个生成模型G的训练策略,是将其置于一个对抗网络的框架中:通过训练一个判别模型,它的训练目标是判断输入的样本是否为真实数据的采样分布还是生成模型G生成的数据。作者将生成模型比作一群货币伪造者,而判别模型等同于警察,试图甄别假钞。在这个对抗的训练过程中,两个模型都尝试尽可能达到各自的目标,直到假钞和真钞无法区分。

经过这五年的时间证明,对抗生成网络GAN的特殊框架适用于多种模型和优化策略,成为深度学习领域中一个经典的框架。

 

对抗网络

当生成模型G和判别模型D都是多层感知机模型(如:CNN,RNN等)时,对抗训练的思想可以直接应用。假设生成模型G要对数据 x 的数据分布 p_g 建模,生成模型的输入噪声 z 符合一个预先定义的概率分布 p_z( z ),通常可以是标准的高斯分布,也可以由另一个卷积网络充当encoder定义。因此,生成模型可表示为 G(z; theta_g),将噪声数据 z 映射到 x 的目标数据空间中,G是一个微分可导的多层感知机模型,theta_g为生成模型的参数。

同时,令 D(x; theta_d) 表示判别模型,输出为一个概率标量。D(x) 表示输入x是采样于真实数据的概率,取值范围为[0, 1]。正如前文提到的,判别模型D的训练目标是尽可能 最大化 区分来自真实数据 x 和生成模型G生成的数据样本的能力,即D(x) ~ 1、D(G(z)) ~ 0;而生成模型G的训练目标是尽可能 最小化 log(1 - D(G(z))),也就是生成的数据G(z)让判别模型D误判为真实数据,即D(G(z)) ~ 1。

如果将D和G看作两个游戏玩家,他们在玩一个针对V(G, D)minimax最小化最大化的游戏,用公式可以表示为:

其中,内层的优化是通过优化D来最大化V,当x从训练数据中采样得来,判别模型的输出D(x)接近于1,log D(x) 接近于0;相反,当输入为噪声z,判别模型的输出D(G(z))应尽可能接近于0,从而最大化log(1 - D(G(z)))的值接近于0。外层的优化通过优化G以尽可能最小化V的值,与判别模型的目标相反,尽可能使得D(G(z))接近于1,混淆判别模型的判别能力。如果只单纯对D进行优化,模型很快会过拟合。为避免过拟合,通常会先对判别模型D进行k次优化迭代,然后对生成模型G进行1次优化,然后就此策略交替优化到D(x) ~ 0.5,G的loss值变化减慢。

整个算法过程可以用伪代码表示如下:

DL | GAN: Generative Adversarial Nets 生成对抗网络算法学习_第1张图片

值得注意的是,在训练过程的初期,生成模型G的数据表达能力较弱,判别模型D可以很容易判断出G生成的数据。也就是说log(1 - D(G(z)))很容易达到饱和。为了得到更好的生成模型,我们可以把最小化log(1 - D(G(z)))的过程转化为最大化log(D(G(z))),这样在训练初期时就可以得到比较大的梯度以加快生成模型的收敛速度。

下图是判别模型和生成模型就训练数据和噪声数据的训练过程示例。图中,判别模型的数据分布表示为(D,蓝色,段虚线),真实训练数据的分布p_x表示为(x,黑色,点虚线),而生成模型的数据分布p_g表示为(G,绿色,实线)。位于下方的两条直线表示真实数据x和噪声数据z,在这个示例中噪声数据是均匀采样的,如箭头的间隔所示。箭头的指向可以理解为噪声数据z映射到x的数据域 x = G(z)。图中(a)所示,假设对抗网络接近收敛,p_g已经近似于p_x,而D只能进行部分的正确判断。(b)图所示为内层循环的优化过程中,判别模型尽可能区分生成数据和真实数据,使得D*(x) = p_x(x) / (p_x(x) + p_g(x))。(c)图所示为生成模型G更新后,判别模型D的梯度会引导G(z)将数据分布映射到更接近于真实数据分布的数据域中。(d)在多次训练迭代过程之后,若G和D具有充分的数据表达能力,数据分布最终会达到p_g = p_x,判别模型D将无法判断数据的真假,即D(x) = 1/2。

DL | GAN: Generative Adversarial Nets 生成对抗网络算法学习_第2张图片

 

理论证明

在对抗网络结构中,生成模型G的目标是将噪声数据z映射到真实数据x的数据域中,其中z服从某个先验的已知数据分布 z ~ p_z,生成模型G(z)服从数据分布p_g。当训练时间充分,模型表达能力充分时,生成模型可以使得p_g无限接近p_x。作者在论文中指出:

1. 对抗模型的训练过程,存在一个全局最优解使得 p_g = p_x。

2. 这个全局最优解可以通过优化上述的公式V(D, G)得到。

论文中作者给出了具体详细的论证过程,在此不再赘述,感兴趣的读者可以去看原文,也欢迎大家在评论区讨论。

 

实验

2014年时,(随着时间推进,已经有太多精彩的实验^_^)作者在多个数据集上进行了实验,下面就MNIST进行具体说明。MNIST数据集中包含手写数字的数据,GAN的训练目标是生成判别模型无法区分的手写数字的图像。在此实验中,生成模型和判别模型都使用多层感知机实现,部分实验结果如下图所示。作者指出,GAN生成数据一个值得关注的问题是,需要找到恰当合适的方法和指标去评估生成数据的质量。即便没有明确的指标证明,GAN生成的数据从肉眼上观察也同样具有一定的优势。

DL | GAN: Generative Adversarial Nets 生成对抗网络算法学习_第3张图片

上图为部分GAN的生成数据,分别在数据集 a) MINIST b) TFD c) CIFAR-10 (全连接模型) d) CIFAR-10 (卷积判别模型和反卷积生成模型)上的实验结果。最右列为训练数据集中的数据,为生成数据最接近的原始数据。实验结果也可以观察到生成数据并没有完全复制训练数据。

 

优点与缺点

1. 缺点

由于无法明确地表示生成的数据分布p_g(x),判别模型D必须在训练过程中与生成模型G保持同步,尤其是G不能更新太多次而不更新D,以避免G生成太多分岔数据同时对应x。

 

2. 优点

首先是模型中不需要马尔可夫链,只需要反向传播即可进行优化,因此具有很强的模型普适性。另一个优点是GAN主要依赖于计算进行优化,不仅生成模型的优化不直接与训练数据相关,而且梯度变化还与判别模型相关。这说明生成模型的学习过程不是直接从数据中进行数据拷贝,而是在学习训练数据的数据分布。此外,对抗网络还能表示出更清晰锐化的生成数据,而以往的生成模型通常会生成比较模糊的结果。

 

结论与展望

GAN对抗生成模型具有很好的拓展性:

1. 可以很容易拓展到条件生成模型p(x | c),只需要在生成模型G和判别模型D中加入条件数据c。这个拓展在没多久之后就产生了日后图像生成GAN系列中的DCGAN,还有后续的pix2pix,以及今年初的GauGAN等等。时间真是检验经典的唯一标准。

2. 如果想得到近似效果的inference结果,可以在已知训练数据x时,增加一个辅助网络去预测z。类似于后续的style transfer应用,可以预先学习出初始噪声z的分布,以生成某种特定风格的图像。

3. 令S为训练数据x的一个子集,模型可以几乎得到所有p(x_s | x_no_s)的数据分布。即通过不包含部分数据的数据集,学习出其余数据的数据分布。

4. 半监督学习的应用。当标注的数据量有限时,判别模型D及生成数据可以用于扩充数据量,以提高分类模型的性能。

5. 效率上的提升。如果找到合适的G和D、或者更合适的噪声z的数据分布,GAN的训练过程可以很好地提升。

 

以上为GAN的学习笔记,主要依照于作者的论文Generative Adversarial Nets进行记录,全部经过我个人的理解进行描述,如果有不恰当的地方,欢迎大家指出和讨论。后续是动手用tensorflow实现GAN的训练和测试,敬请期待更新啊。(自言自语~

 

 

PS 在看一个老电影my best friend's wedding 哈哈哈,正好有这句:This too, shall pass. 祝好啊~

 

 

你可能感兴趣的:(DL,CV)