【GAN理论与WGAN】——基于李宏毅2021春机器学习课程

GAN

理论

目标

【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第1张图片

  • 输入一个Normal Distribution到Generator里,得到一个复杂的Distribution,PG。
  • 我们还有一个标准的data集,形成另一个Distribution,Pdata。
  • 因此,定义一个loss function(损失函数)来让PG和Pdata之间的Divergence(可以理解为距离)越小越好,如上图公式。

我们目标就是找一组Generator的参数(简写为G*)让divergence越小越好。

GAN有个问题——divergence不好计算

(老师就单纯说了计算divergence很复杂,不好计算,没有具体展开 =。=)
【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第2张图片
虽然我们不知道PG和Pdata的formulation,但我们只要在PG和Pdata中sample出一组数据(即Sampling from Pdata & Sampling from PG),就有办法计算divergence。

方法

通过Discriminator
【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第3张图片

  1. Discriminator看到Pdata的数据就给高分,看到PG的数据,就给低分(分数 = log( D ( Y ) ) )
  2. 可以看作一个optimization(最优化)——去maximize一个Objective Function V(D,G)问题(minimize则为Loss Function)
  3. 公式希望V越大越好,则从Pdata sample出来的分数要越大越好,从PG sample出来的则越小越好,即找一个Discriminator(简写D),在给定的D和G下来maximize这个Objective Function

写成这个公式是因为想和classifier相联系一起,在这里插入图片描述

事实上这个Objective Function就是cross entropy 乘上一个负号 ,因此我们maximize的时候等同于minimize cross entropy 即等同于训练一个classifier。
【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第4张图片

因此可以把Discriminator看成一个classifier,把从Pdata sample出来的分为class 1,把从PG sample出来的分为class 2, 训练出一个binary classifier,则等同于训练出一个D*
在这里插入图片描述
核心在于这个Objective Function的最大值与divergence有关,直观理解如下:
【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第5张图片
当我们数据的divergence很小的时候,所训练出来的discriminator就很难分辨出两者的差异,或者把discriminator看成classifier,则很难分开两个类,因此maxV(D,G)就很会很小,相反当divergence很大时,就很容易分别出差异,从而maxV(D,G)就很很大。(详细证明参见GAN原始论文)

因此我们可以将计算divergence转化为求maxV(D,G)。
【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第6张图片
我们要找一个Generator去minimize橙色方框中的值, 然后橙色方框中是给定一个Generator找一个Discriminator来让Objective Function越大越好,找到的G则是我们的目标G*
所以我们可以借助Objective Function求出D*
然后求解出G*
具体求解过程(step1,step2)参照GAN简介中的解释,也可以参照原始GAN的论文。

所以GAN非常难训练——No pain,No GAN。

Tips for GAN

首先JS divergence(GAN中常用的Divergence)不合适【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第7张图片

大多数情况,PG和Pdata重叠的部分很少

  1. 本身数据的原因:
    可以理解为:一张图片在高维空间中,就跟一条直线在三维空间中, 只有非常小的范围(这跟线上)的向量能组合成一张图片,因此PG和Pdata就类似于两条直线,除非完全重合,否则的相交的点(重叠的部分)几乎可以忽略。
  2. Sampling的原因:
    首先我们并不知道PG和Pdata具体由多少部分重叠,所以尽管PG和Pdata有重叠的部分,但我们sample 出来的点不够多,不够密集,则同样也可以画出一条把PG和Pdata1完全分离的分界线,从而看上去PG和Pdata没有重叠

而没有重叠的分布,无论原本是什么分布,算出来的JS divergence永远是log2

【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第8张图片
因此,只要两个数据集没有重合,就算新生成的数据集更接近了测试集,但是算出来的JS都是log2,G0和G1看上去都是同样的差,或者同样的好,无法更新出新的Generator参数。
而且当使用JS divergence来训练binary classifier时,最终回答道100%正确率,这是因为本身sample出来的两个PG和Pdata不够多,比如只有256张,因此机器直接将数据”背"下来,都可以将这两者分开。因此,训练GAN的时候用训练binary classifier的方法来训练discriminator,最终正确率都是100%,因此原本我们想通过每次迭代的正确率来看是否由产生出更好的Generator的目标无法实现

WGAN——把JS divergence换成 Wasserstein distance

【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第9张图片
有两个distribution P、Q,wasserstein distance计算方法:
想象驾驶一台推土机,将P这堆土移动到Q的位置所移动的平均距离d就是wasserstein distance
【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第10张图片
但我们两个分布比较复杂的时候,我们有多种推土的方法来让P变成Q,由此导致的平均推土距离d都有所不同,因此,我们最终是穷举所有可能的推土方法,找到最短的d来作为wasserstein distance
【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第11张图片
因此,我们通过计算wasserstein distance所得的d0,d1,从而找到让PG和Pdata更接近的G*

怎么计算Wasserstein distance

【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第12张图片
公式:maximize 从Pdata来的x的Discriminator的分数的期望,减去从PG来的x的Discriminator的期望,最终得到效果是从Pdata来的x的分数越高越好,从PG来的x的分数越小越好。
【GAN理论与WGAN】——基于李宏毅2021春机器学习课程_第13张图片
同时Discriminator要是一个足够平滑无剧烈变动的function(1-Lipschitz),因为:
我们需要的是real的值(从Pdata来的x的分数)越大越好,generated的值(从PG来的x的分数)越小越好,那么Discriminator就会让real变成都是+∞,generated的值变成都是-∞,就像JS divergence一样都是log2,导致无法收敛,因此当我们限制Discriminator的形状后,就不会变化到±∞,否则就是变化剧烈。

如何限制Discriminator

Spectral Normalization——一个比较好的限制方法(然后就下课了=。=)

你可能感兴趣的:(机器学习,深度学习,pytorch,GAN)