机器学习笔记9_Generative Model(GAN)

本文是李宏毅机器学习的笔记,这是第九节,介绍了GAN网络。

文章目录

      • 1. GAN概述
        • Generator
        • Discriminator
        • 流程
      • 2. GAN原理
      • 3. Tips for GAN
      • 4. Conditional Generation
      • 5. Learning from Unpaired Data
      • 6. Evaluation of Generation

1. GAN概述

Generator

Generator可以利用一个简单的分布数据生成一个复杂的分布数据。
机器学习笔记9_Generative Model(GAN)_第1张图片
为什么不做成输入和输出相互对应的方式,而要将简单的分布当作输入呢。为什么使用分布呢?
首先看下面一个例子,预测视频的下一帧画面,从之前的画面中,小人可能向左转也可能向右转,如果这两种方式同时被预测到,那么就可能会出现,小人同时向左右转这种不符合事实的情形。
机器学习笔记9_Generative Model(GAN)_第2张图片
这时我们可以加入一个简单的分布,可以根据这个简单的分布确定是向左转还是向右转。
机器学习笔记9_Generative Model(GAN)_第3张图片
机器学习笔记9_Generative Model(GAN)_第4张图片
GAN网络的一些实现如下图所示:
机器学习笔记9_Generative Model(GAN)_第5张图片
如下图所示,GAN将简单的正太分布变为复杂的分布,将低维度的向量转变为高维度的向量。
机器学习笔记9_Generative Model(GAN)_第6张图片

Discriminator

判别器对生成器的输入进行判定,然后给定信心分数。
机器学习笔记9_Generative Model(GAN)_第7张图片
下图是生成器和判别器合作的一个例子,二者不停的进行更新,生成器想要生成判别器都认为是真的图片。判别器想要判别出所有不为真的图片。二者相互合作,共同更新。
机器学习笔记9_Generative Model(GAN)_第8张图片

流程

算法流程如下所示:
机器学习笔记9_Generative Model(GAN)_第9张图片
机器学习笔记9_Generative Model(GAN)_第10张图片
GAN可以用于生成动漫图片。
机器学习笔记9_Generative Model(GAN)_第11张图片

2. GAN原理

我们的目标是,根据简单的分布,然后生成一个复杂的分布,这个复杂的分布需要和预先设定的数据分布尽可能的相近。
机器学习笔记9_Generative Model(GAN)_第12张图片
其中 P G P_G PG P d a t a P_{data} Pdata的分布,分别从生成器的输出、已有的数据库中进行提取。
机器学习笔记9_Generative Model(GAN)_第13张图片
对于判别器,其训练过程如下所示,目标函数有点像cross entropy函数,所以可以当成一个分类任务最小话信息熵就可以。最大化目标函数与数据的散度是有关系的。
机器学习笔记9_Generative Model(GAN)_第14张图片
散度较小,不好判别,生成器也就欺骗过了判别器。散度较大,容易判别,生成器没有骗过判别器。
机器学习笔记9_Generative Model(GAN)_第15张图片
运算流程如下,生成器希望散度更小,以欺骗判别器。判别器希望散度大一些,以识别出生成器的欺骗。
机器学习笔记9_Generative Model(GAN)_第16张图片
下图是一些其它的divergence
机器学习笔记9_Generative Model(GAN)_第17张图片

3. Tips for GAN

GAN并不是一个很好训练的网络。
首先,JS divergence并不是很合适,因为如果用JS divergence的话, P d a t a P_{data} Pdata P G P_G PG几乎没有重合的,如果没有大量的数据,重合的地方也不是很多。
机器学习笔记9_Generative Model(GAN)_第18张图片
JS divergence的一大缺陷是,如果 P d a t a P_{data} Pdata P G P_G PG一点都没有重合的话,那么它们算出来的散度是一样的,这是不合适的。这回导致GAN在训练时什么都学不到。
机器学习笔记9_Generative Model(GAN)_第19张图片
更合适的是用Wasserstein distance来代替JS divergence,使用Wasserstein distance就可以避免上面的问题。
机器学习笔记9_Generative Model(GAN)_第20张图片
机器学习笔记9_Generative Model(GAN)_第21张图片
机器学习笔记9_Generative Model(GAN)_第22张图片
WGAN就是使用了Wasserstein distance的GAN,如果y属于 P d a t a P_{data} Pdata,那么为了使目标函数尽可能大,所以 D ( y ) D(y) D(y)应该尽可能大。同理当y属于 P G P_{G} PG时, D ( y ) D(y) D(y)应该尽可能小。为了不两极分化,一边趋向于正无穷,一边趋向于负无穷,所以,加了一个现实,使得D具有一定的平滑性。
机器学习笔记9_Generative Model(GAN)_第23张图片
下图展示了如何实现这种限制,传统的WGAN,就是超过最大最小的值,变为最大最小值。改进的方法包括Gradient Penalty、Keep gradient norm等方法。
机器学习笔记9_Generative Model(GAN)_第24张图片
不过尽管如此,GAN依旧不是一个容易训练的模型。
机器学习笔记9_Generative Model(GAN)_第25张图片
机器学习笔记9_Generative Model(GAN)_第26张图片
GAN可以用于语句生成。
机器学习笔记9_Generative Model(GAN)_第27张图片
机器学习笔记9_Generative Model(GAN)_第28张图片
机器学习笔记9_Generative Model(GAN)_第29张图片
机器学习笔记9_Generative Model(GAN)_第30张图片

4. Conditional Generation

条件生成,就是限制一些条件,让GAN生成一些符合条件的东西出来,例如Text-to-image
机器学习笔记9_Generative Model(GAN)_第31张图片
机器学习笔记9_Generative Model(GAN)_第32张图片
机器学习笔记9_Generative Model(GAN)_第33张图片
机器学习笔记9_Generative Model(GAN)_第34张图片
机器学习笔记9_Generative Model(GAN)_第35张图片
机器学习笔记9_Generative Model(GAN)_第36张图片
机器学习笔记9_Generative Model(GAN)_第37张图片

5. Learning from Unpaired Data

机器学习笔记9_Generative Model(GAN)_第38张图片
机器学习笔记9_Generative Model(GAN)_第39张图片
Cycle就可以从不同领域的数据进行学习。
机器学习笔记9_Generative Model(GAN)_第40张图片
生成器生成的不仅要和输入相近,还要和 P d a t a P_{data} Pdata相近。但是对于传统的GAN网络来说,只需要和 P d a t a P_{data} Pdata相近即可,就可以骗过判别器。
机器学习笔记9_Generative Model(GAN)_第41张图片
所以再加一层,根据前一个生成器的输出,输出和原输入相近的图片。
机器学习笔记9_Generative Model(GAN)_第42张图片
还可以反向操作,根据动漫人物生成真人。
机器学习笔记9_Generative Model(GAN)_第43张图片
一些其它的GAN网络。
机器学习笔记9_Generative Model(GAN)_第44张图片
机器学习笔记9_Generative Model(GAN)_第45张图片
还可以对文本的形式进行改变。
机器学习笔记9_Generative Model(GAN)_第46张图片
机器学习笔记9_Generative Model(GAN)_第47张图片

6. Evaluation of Generation

该如何验证生成图片的质量呢,使用人工的话成本太高。所以需要自己制定一些方式来判断生成图片的质量。例如将生成的图片放进图片分类器中,如果输出结果的概率分布都集中在一个类别,就说明图片分类器有很大的信心分数判定图片的类别,说明生成图片的质量很高,不是四不像。
机器学习笔记9_Generative Model(GAN)_第48张图片
GAN可能会遇到Mode Collapse问题,就是生成图片的分布至于真实数据中的几个相近,也就是总是生成相同的图片。
机器学习笔记9_Generative Model(GAN)_第49张图片
还可能会遇到Mode Dropping的问题,也就是来来回回生成的就那几张脸,多样性不够。机器学习笔记9_Generative Model(GAN)_第50张图片
下图就是评价多样性的一种方式,将每张生成图片经过分类器产生的概率分布加起来,然后求平均,得到平均分布,如果这个分布很平均,那么就说明散度越大,多样性越多,图片的质量越好。评价指标可以是Inception Score(IS),diversity越大,图片质量越高。
机器学习笔记9_Generative Model(GAN)_第51张图片
由于相同的人物,如果眼睛的颜色不一样,也会有一定的diversity,但是人看来,这两个人物是差不多的。所以用IS来进行评估可能会有一些缺陷。所以可以采用FID进行评估。蓝色的点代表生成的图片在分类器的隐藏层输出的分布。红色的点代表真实图片在分类器的隐藏层输出的分布。然后计算这两个之间的FID。越小代表图片质量越高。
机器学习笔记9_Generative Model(GAN)_第52张图片
一些其它的GAN网络
机器学习笔记9_Generative Model(GAN)_第53张图片
GAN的记忆性,这不是好的,与现实数据太相近了。
机器学习笔记9_Generative Model(GAN)_第54张图片
解决方法的文献。机器学习笔记9_Generative Model(GAN)_第55张图片

你可能感兴趣的:(深度学习笔记,机器学习)