注:本文是我学习李宏毅老师《机器学习》课程 2021/2022 的笔记(课程网站 ),文中图片均来自课程 PPT。欢迎交流和多多指教,谢谢!
本节课将介绍 Deep Neural Network 作为 Generator 的应用。前面的课程中,Network 的输入输出是固定的数据。Network 作为 Generator 时,模型的输入包含了从简单分布(这样知道其分布函数)中采的 sample,如下图所示。输入的随机性传递到输出,因此输出是一个复杂的分布。
为什么要用分布 (distribution) 呢?
有些情况下,一个输入有多种可能的输出,不是固定的一对一映射。也就是说,这种模型架构适合“没有标准答案”的情况,例如绘画,聊天机器人 (chat-bot) 等创作型任务。
我们先来看简单的情况,即不考虑输入的条件 x x x ,输入输出都是 distribution 的 GAN 网络。
以下步骤中可以把 Generator 想象成制假钞的人,Discriminator 想象成警察。
第一步:训练 Discriminator 分辨真伪的“火眼金睛”: 此时 generator G 保持不变,只训练 discriminator D。
Discriminator 训练目标:把真数据 (real ones) 和“假”数据 (generated ones) 区分开。具体地说,就是 real ones 经过 D,输出值大(接近 1),generator 产生的数据 (generated ones) 经过 D,输出值小(接近 0),如下图所示:
第二步:Generator 也在提升自己的“造假技术”:此时 discriminator D 保持不变,只训练 generator G。
Generator 训练目标:产生的“假”数据 (generated ones) 让 discriminator 都误认为是真的 (real ones)。具体地说,使 generated ones 经过 discriminator 的输出值越大越好(接近 real ones 的值 1)。实际中可以把 generator 和 discriminator 组成一个大的网络结构,如下图所示,前几层为 generator,后几层为 discriminator,中间 hidden Layer 的输出就是 generator 的输出。这一步固定后几层 (D) 不变,只更新前几层(G)。
结合起来,GAN 就是在不断重复循环上述两步,如下图所示:
我的思考:Generator 或 Discriminator 两个不会同时训练,每次只有一个在训练。训练时,所掌握的是对方上一轮训练的信息。其实这就像是双方每一次交手后,知道对方最新的技术水平,然后回去提升自己。
GAN 中 Generator 和 Discriminator 的关系:既是敌人,又在这种较量中促进了彼此的提高,这就是 GAN (Generative Adversarial Network) 中 Adversarial 的由来,在对抗中提升。这不由得让人想到大自然中的天敌促进物种进化,如下图所示,蝴蝶在躲避天敌的捕食中,从开始外表变成棕色像枯叶,到后来都有了叶子的纹路,一步步进化成枯叶蝶,停在树上就像一片枯叶,效果逼真。
GAN 的训练目标:从简单的 Normal Distribution 采样的数据,经过 Generator,得到的输出分布 P G P_G PG 接近目标分布 P d a t a P_{data} Pdata,如下图所示。也就是说,要找到使 P G , P d a t a P_G,P_{data} PG,Pdata 的 divergence 最小 (即 D i v ( P G , P d a t a ) Div(P_G,P_{data}) Div(PG,Pdata) 最小) 的 G。
那么问题来了,要怎么计算 D i v ( P G , P d a t a ) Div(P_G,P_{data}) Div(PG,Pdata) 呢?
只要可以做采样,就可以计算 D i v ( P G , P d a t a ) Div(P_G,P_{data}) Div(PG,Pdata) 。如下图所示,从 data 里面 sample,就得到 Sampling from P d a t a P_{data} Pdata。从 Normal Distribution 中 sample,经过 Generator,得到的输出就认为是 Sampling from P G P_G PG。
为什么可以用 sampling 来代表 distribution 计算 divergence?
这与 Discriminator 的设计有关。Discriminator 就是要尽量把从 P d a t a P_{data} Pdata sample 的数据与从 P G P_G PG sample 的数据分开,这其实也可以用 Binary Classifier 做,把 P d a t a P_{data} Pdata sample 当作 class 1, P G P_G PG sample 当作 class 2,如下图所示。设计 Classifier 的目标函数 V ( G , D ) V(G,D) V(G,D) 为: P d a t a P_{data} Pdata sample 经过 Discriminator 得到的分数为 l o g D ( y ) log D(y) logD(y) ,这个分数要尽量高; P G P_G PG sample 经过 Discriminator 得到的分数为 l o g ( 1 − D ( y ) ) log(1-D(y)) log(1−D(y)) ,这个分数要尽量低。因此,训练的目标要使 V ( G , D ) V(G,D) V(G,D) 最大。而经过一番推导(详细请看 GAN 原始论文),发现这个值与 JS divergence 有关。
下面我们通过例子来直观理解 m a x V ( D , G ) maxV(D,G) maxV(D,G) 与 JS divergence 的关系。如下图所示,假如 P d a t a P_{data} Pdata 与 P G P_G PG 的 divergence 小,Discriminator 很难分辨两者,没办法准确打分,因此 m a x V ( D , G ) maxV(D,G) maxV(D,G) 小。反之,假如 P d a t a P_{data} Pdata 与 P G P_G PG 的 divergence 大,Discriminator 很容易就把两者分开了,得到的 m a x V ( D , G ) maxV(D,G) maxV(D,G) 大。
如果说训练 Discriminator的目标是使 V ( G , D ) V(G,D) V(G,D) 最大,那么 Generator 的目标就是要使这个最大可能的 V ( G , D ) V(G,D) V(G,D) 尽量小,也就是说使 P G P_G PG 和 P d a t a P_{data} Pdata 的 JS divergence 尽量小。因此,就有下图所示的 m i n m a x min\ max min max 两个符号。
当然,除了 JS divergence,还可以使用其它的 divergence,只要设计 discriminator 相应的目标函数即可。
GAN 是出了名的不好训练。下面介绍知名的一个 tip: WGAN。
很多情况下, P G P_G PG and P d a t a P_{data} Pdata 并不重叠。JS divergence 的问题:只要分布不重叠的两类, J S ( P G , P d a t a ) = l o g 2 JS(P_G,P_{data})=log2 JS(PG,Pdata)=log2,Discriminator 始终可以准确分开这两类,如下图所示。这导致 Generator 无法知道训练是否带来结果的提升,训练学不到东西。
解决方法:改用 Wasserstein distance。我们把分布想象成一个小土堆,Wasserstein distance 计算的是要把土堆(分布)A 变换到土堆 B,推土机挪动土堆 A 的距离。
对应到分布的变换距离,计算方法如下图所示。当然,你可能也想到了,这样的移动方式有很多种,选最小移动距离 (smallest average distance) 的方式就可以了。
Wasserstein distance 的好处是计算出了两个类别分布的距离大小,因此能够捕捉一点一滴的改变(或者说进化),如下图所示。这样,Generator 就可以根据结果来一点点提高。
我的思考:俗话说得好,“一口吃不成个胖子”,量变产生质变。这里 Wasserstein distance 较 JS Divergence 的改进之处正在于注重了对“量变”的及时反馈。JS Divergence 就像是一步到位,从新手一下子就到高手,可想而知这很难。而 Wasserstein distance 就是对每一阶段的努力都有一个反馈:现在是进步还是退步了,一点点地提高,最终到达目标。
使用 Wasserstein distance 的目标函数如下图所示。注意这里有一个条件 D ∈ 1 − L i p s c h i t z D \in 1-Lipschitz D∈1−Lipschitz,换言之, D D D 函数要平滑。为什么呢?如果没有这个条件设置,为了让目标函数的值最大,D 函数就会变成:对于 P d a t a P_{data} Pdata sample, D ( y ) = + ∞ D(y)=+\infty D(y)=+∞;对于 P G P_G PG sample, D ( y ) = − ∞ D(y)=-\infty D(y)=−∞。这就会和 JS divergence 一样,训练中学不到东西。而加上这个平滑限制,如下图所示,因为曲线要连续而不能剧烈变化,也就不会到 ∞ \infty ∞ ,这才是 Wasserstein distance,能让 Generator 渐进提高。
但是,即使是用了 Wasserstein distance,GAN 还是不好训练。一个内在的原因是:GAN 结构的 Generator 和 Discriminator 要棋逢对手,如果有一个没有提高,那么另一方也会停止改进,如下图所示。在实际训练中,无法保证每一次 Discriminator 的 loss 都会下降,一旦 loss 不下降,就会出现连锁反应,整个结构都不再改进。
而如果把 GAN 用在文字生成上,那训练起来更是难上加难,为什么呢?
如下图所示,我们可以把 Transformer 的 Decoder 部分看成是 GAN 的 Generator ,生成的 sequence 送入 Discriminator 中判断是不是真的文字。这里的问题是 loss 没办法做微分。为什么呢?假设 Decoder 的输入有微小的变化,因为 Generator 的输出是取概率最大的那个,输出 sequence 不变,进而 Discriminator 的输出也不变,没有变化也就算不出微分。
疑问:为什么 CNN 的 Max Pooling 可以做微分,这里用了 Max 又没法做微分了呢?
如果输入有变化,CNN 的 Max Pooling 是一种采样,y 的变化还会保留。而这里的 Max 操作是取的 y 对应的类别,y 微小的变化对于类别影响不大。这是我的分析,如果有误,请多多指教,谢谢!
除了 GAN,其它的 Generative Models 还有 Variational Autoencoder (VAE) 和 FLOW-based Model。
简单 GAN 产生的图片天马行空,可能不是我们想要的,所以要加入一些限制条件。
简单 GAN 是不需要标注的,这里的 conditional GAN 则需要一些标注,也就是说引入了有监督学习。这也好理解,既然对机器产生的数据有一定要求,肯定要有示例告诉机器应该怎么做。
具体来说,就是 Discriminator 的输入为成对数据 ( x , y ) (x,y) (x,y)。以文字生成图片 (Text-to-image) 为例,Discriminator 的训练目标是:输入为(文字,与之对应的训练图片),输出为 1;输入为(文字,生成的图片)时,输出为 0。除此之外,还需要一种 negative sample:(文字,与之不对应的训练图片),输出为 0。如下图所示:
更多应用例子:
1.Image translation (pix2pix),比如:黑白到彩色,白天景物到夜景,轮廓素描到实物图。例如:从建筑结构图到房屋照片的转换效果如下图所示,如果用 supervised learning,得到的图片很模糊,为什么?因为一个建筑结构图对应有多种房屋外形,这样训练时机器就会考虑多种情况,做平均。如果用 GAN,机器有点自由发挥了,房屋左上角有一个烟囱或窗户的东西。而用 GAN+supervised,也就是 conditional GAN,生成的图片效果就很好。
2.sound-to image:从声音生成相应的图片,比如输入水声,生成溪流图片。
3.talking head generation:静态图转动态,让照片里的人物动起来。
GAN 的妙用:unsupervised learning ,也就是接下来要介绍的 Cycle GAN。
实际中,常常有大量未标注数据,怎么利用上这部分数据呢?
有一个方法是 semi-supervised learning,只需要少量标注数据,未标注数据可以用模型标注 (pseudo label)。但是尽管是少量,还是要用标注数据来训练模型,否则模型效果不好,标注也不好。
如果一点标注数据都没有,怎么办?
你可能会说,不会吧,总可以人工标注一点数据吧?
一个例子就是 Style Transfer,例如图像风格转换,假设我们有一些人脸图片,另外有一些动漫头像,两者没有对应关系,也就是 unpaired data,如下图所示。Cycle GAN 就是为了解决这个问题。
与前面介绍的 GAN 不同,Cycle GAN 的输入不是从 Gaussian Distribution 采样,而是从 original data 采样,用 G x − > y G_{x->y} Gx−>y 生成动漫头像图片,如下图所示:
但这样生成的二次元动漫头像可能和原图不对应。而因为没有标注,又不能用 conditional GAN 来做。Cycle GAN 增加了一个 G y − > x G_{y->x} Gy−>x ,把生成的动漫图片再变换到人物图片。训练使 G y − > x G_{y->x} Gy−>x 生成的人物图片与原图尽量接近,以此达到了原图和生成动漫头像的对应,如下图所示:
此外,还可以反向训练,从动漫图片到人物图片,再到动漫图片。训练 Cycle GAN 时可以两个方向同时训练。
Cycle GAN, Disco GAN, Dual GAN:是一样的,不同研究团队在同一时间提出,因此有不同命名。
类似的应用:文字风格转换 (Text Style Transfer),比如把消极的文字都转换为积极的文字,如下图所示:
对于 Generative Model,怎么评估生成数据,比如生成图片的好坏呢?有监督学习可以和 label 比对,而 Generator 生成的图片与原来的图片相似但不相同,怎么去判断呢?
对于作业中的二次元人物头像,可以用人脸检测,看生成的一组图片中能检测出多少人脸。
对于更一般的情况,生成多种类的图片,比如有猫、狗等不同种类,可以设计一个 Image Classifier。如果概率分布集中在某个类别,说明 Classifier 对于输出的类别很确定。也就是说,Generator 生成的图片质量好。而如果概率分布平均,说明 Classifier 不太确定看到的图片属于哪个类别,Generator 生成的图片质量不佳,Classifier 都认不出这是什么。
但是这样也有一些问题,就是生成的数据可能集中在某部分区域,比如下面两种情况。
(1) Mode Collapse
看下图你可能会发现,有一个相似的头像图片反复出现。这就是 generated data 集中在某一个 real data 周围的情况,这个区域可以认为是 discriminator 的“盲区”,区域内的图片判定为真的可能性大,因此 generator “投机取巧”,反复生成这种图片。
(2) Mode Dropping
如下图所示,如果看左边的数据分布,会觉得 generated data 还不错,但其实 real data 的多样性不止于此,还有右边部分的分布。
要解决上述问题,可以把一组 generated data 输入 CNN Classifier,然后把得到的各分类概率分布取平均作为结果。如果这个平均概率分布中,各类别的分布比较平均,那就说明 generated data 有足够的 diversity。
疑问:为什么前面 Quality of Image 说要概率分布集中在某个类别好,这里 Diversity 又说要概率分布均匀好,这不是互相矛盾吗?
看 Quality of Image 时,Classifier 的输入是一张图片。看 Diversity 时,Classifier 的输入是 Generater 生成的所有图片,对所有的输出取平均来衡量。
Inception Score (IS) 就是结合了 Quality of Image 和 Diversity。Quality 高, Diversity 大,对应的 IS 就大。
而对于作业中的生成二次元人物头像图片,不能用 Inception Score,因为都是人脸图片,Classifier 都识别为一类,因此 Diveristy 不高。
解决方法:用 Frechet Inception Distance (FID)。如下图所示,分析 CNN 的输出,也是 Softmax 的输入,这些向量之间有差异。图中红色点是真实的图像,蓝色点表示生成的图像,FID 计算的是两个 Gaussian Distribution 之间的 Frechet Distance。这里做了一个假设:真实和生成的图像都是 Gaussian Distribution。
那么,有了这些衡量标准,是不是就可以衡量出 GAN 的好坏了呢?
有时生成图片的 Quality 和 FID 都不错,可是你看图片总觉得哪里不对,比如下图中第二行的图片:
和训练图片 (real data) 一对比,发现机器学到的是和训练图片一模一样!可是我们希望机器能生成新的图片,如果和训练图片一模一样,直接到训练图片集采样就好了。
应对方法:把 generated data 和 real data 计算相似度,看是不是一样。
新的问题:机器可能会学到把训练图片左右反转一下,如图中第三行图片所示,计算相似度是不同,其实还是原图片。
所以说,衡量 Generative Model 的好坏挺难的。
觉得本文不错的话,请点赞支持一下吧,谢谢!
关注我 宁萌Julie,互相学习,多多交流呀!
阅读更多笔记,请点击 李宏毅老师《机器学习》笔记–合辑目录。
李宏毅老师《机器学习 2022》:
课程网站:https://speech.ee.ntu.edu.tw/~hylee/ml/2022-spring.php
视频:https://www.bilibili.com/video/BV1Wv411h7kN