GAN生成对抗网络是最近比较火的深度学习技术,这里记录一下自己学习GAN的笔记,以便日后复习。本文不会给出每种模型的细节,只是提一下主要区别和trick,同时给出相关参考链接。
GAN可以做很多事情,如自动生成动漫人物头像;做pix2pix(image2image)的工作,如给黑白图片上色,基于模糊图片生成高清图片,素描生成真实照片,将风景画“莫奈化”等;也可以用于最近比较火的AI换脸,GAN掉马赛克(修补图片缺失的部分),用于生成万能指纹等。
先来复习一下自编码器,自编码器结构如下图所示:
当自编码器训练好后,输入一个随机的code向量到Decoder中,理论上Decoder会生成一张图片。
Auto-Encoder的问题:Decoder的输入向量需要是一张图片经过Encoder生成的,否则无法任意生成图片(随机的输入向量生成的图片不稳定,很难人为构造比较好的隐藏向量),相当于Decoder只认识学习过的图片。
在Auto-encoder的基础上做了改进,encoder会生成两组向量,一组代表均值,一组代表标准差,迫使其生成的隐含向量能够粗略的遵循一个标准正态分布,这样就可以构造正态分布的向量,通过Decoder来生成稳定图片。(均值和标准差都是假设的,学习过程中会往假设方向更新)在标准差上叠加一组noise(高斯分布的)。
下面定义了VAE的loss=reconstruction_loss+KL_loss,其中KL_loss(latent loss)可以参考https://zhuanlan.zhihu.com/p/22464760的公式推导,假设前提是sample是从正态分布中采样的。
loss = reconstruction_loss + latent_loss
reconstruction_loss = mean(square(generated_image - real_image))
# latent_loss = KL-Divergence(latent_variable, unit_gaussian)
# z_mean(均值) and z_stddev(标准差) are two vectors generated by encoder network
latent_loss = 0.5 * tf.reduce_sum(tf.square(z_mean) + tf.square(z_stddev) - tf.log(tf.square(z_stddev)) - 1,1)
samples = tf.random_normal([batchsize,n_z],0,1,dtype=tf.float32)
sampled_z = z_mean + (z_stddev * samples)
VAE比auto-encoder改进了很多(auto-encoder只会记住看过的图片),但是VAE只是更好地学习了对照,而没有真正地学会生成。
GAN最大的问题是难以训练,DCGAN在GAN的基础上添加了一些trick,拓展了维度,并且训练成功:
1.去掉了G网络和D网络中的pooling layer。
2.在G网络和D网络中都使用Batch Normalization
3.去掉全连接的隐藏层
4.在G网络中除最后一层使用RELU,最后一层使用Tanh
5.在D网络中每一层使用LeakyRELU。
•G网络:100 z->fc layer->reshape ->deconv+batchNorm+RELU(4) ->tanh 64x64
•D网络(版本1):conv+batchNorm+leakyRELU (4) ->reshape -> fc layer 1-> sigmoid
•D网络(版本2):conv+batchNorm+leakyRELU (4) ->reshape -> fc layer 2-> softmax
LSGAN修改了原始GAN的loss function,原始GAN是log损失,对应的优化目标是KL散度和JS散度;而LSGAN则是L2损失,对应Pearson散度。
我们选择 b=1 表明它为真实的数据,a=0 表明其为伪造数据。 c=1 表明我们想欺骗辨别器 D。
但是这些值并不是唯一有效的值。LSGAN 作者提供了一些优化上述损失的理论,即如果 b-c=1 并且 b-a=2,那么优化上述损失就等同于最小化 Pearson χ^2 散度(Pearson χ^2 divergence)。因此,选择 a=-1、b=1 和 c=0 也是同样有效的。
在LSGAN之前已有WGAN通过使用Wasserstein距离度量替代JSD距离度量,解决了GAN难以训练的问题。但WGAN训练较慢,而且需要一些特别的剪枝等操作辅助。LSGAN则比WGAN要快,又比原始GAN稳定。
选择最小二乘Loss做更新有两个好处, 1. 更严格地惩罚远离数据集的离群Fake sample, 使得生成图片更接近真实数据(同时图像也更清晰) 2. 最小二乘保证离群sample惩罚更大, 解决了原本GAN训练不充分(不稳定)的问题。
但有好处的同时也会带来问题:LSGAN对离离群点的过度惩罚, 可能导致样本生成的”多样性”降低, 生成样本很可能只是对真实样本的简单”模仿”和细微改动.
参考:LSGAN-最小二乘GAN, 最小二乘GAN:比常规GAN更稳定,比WGAN收敛更迅速
在标准的GAN中,生成数据的来源一般是一段连续单一的噪声z,这样带来的一个问题是,Generator往往会将z高度耦合处理,我们无法通过控制z的某些维度来控制生成数据的语义特征,也就是说,z是不可解释的。我们希望每个手写数字可以分解成多个维度特征:代表的数字、倾斜度、粗细度等。
CGAN:G网络的输入在z的基础上多一个输入y,D网络的输入在x和G(z,y)的基础上也多一个y。(例如输入一个数字,然后输出对应数字的图片。)
Positive sample:(c, x)
Negative sample: (c, G(z,c))和 (cfake,c_fake,x)
论文:InfoGAN
代码:openai InfoGAN
公式推导解析:infoGAN公式推导
•InfoGAN:G网络的输入除了z之外不再使用标签y,而是换成一个latent code c。在不加限制的情况下,网络会自动忽略掉c,直接训练出G(z,c)=G(z)的generator,所以添加了G(z,c)和c之间的互信息作为损失函数的一部分,让generator自己学习到潜在编码c。
•想估计互信息I比较难,因为真实的分布P(c|x)无法获得,所以使用辅助分布Q(c|x)来估计(逼近) P(c|x) ,对于c,如果是categorical latent code,可以使用softmax的非线性输出来代表Q(c|x);如果是continuous latent code,可以使用高斯分布来表示。
其中Q网络和D网络共享一个网络,基于X输出c,只是在最后一层分开独立输出,H(c)是c的信息熵
其中:L1使用蒙特卡洛来逼近
•相比CGAN:
•D网络的输入只有x,不加c。
•D网络除了二分类的输出,还输出一个c(Q网络)
•GAN的问题:
在最优判别器的条件下,原始GAN定义的生成器loss等价变换为最小化真实分布与生成分布之间的JS散度 ,JS散度越小,生成器性能越好。然而JS散度只在两个分布有重叠的时候起作用,当两个分布没有重叠部分时,JS散度是常数log2,导致梯度为0,梯度消失。
1.等价优化的距离度量KL散度和JS散度不合理
2.生成器随机初始化后生成的分布很难与真实分布有不可忽略的重叠
WGAN提出了一种新的度量方法:Wasserstein距离(又叫Earth-Mover)
Wasserstein距离相比KL散度、JS散度的优越性在于,即便两个分布没有重叠,Wasserstein距离仍然能够反映它们的远近。
论文(patchGAN):Image-to-Image Translation with Conditional Adversarial Networks
在image-to-image任务基础上,有了效果更好的cycleGAN,而且cycleGAN使用的数据不再是成对的数据,而是使用unpair的数据进行训练即可,通用性更强(不过在论文里,作者还是主要使用的有一些pair性质的数据集做的实验)。
CycleGAN就是在原始GAN的基础上做了一个逆向过程,即X转换为Y后,再从Y转换回X,Loss也是在原始GAN loss的基础上多了逆向的GAN loss,此外还添加了针对X和Y的L1 loss:
Adversarial Loss(这是正向,还有一个逆向):
详细解析见CycleGAN解析博客
lua代码见https://github.com/junyanz/CycleGAN
论文见Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
深度解密换脸应用Deepfake
相关博客:https://danieltakeshi.github.io/2017/03/05/understanding-generative-adversarial-networks/
与原始minimax版本相比,不饱和的版本,根据图可以看出,在G的性能不高时能够更快地收敛,可以加快G的训练速度。
mode collapse 模式崩塌
从泛化性到Mode Collapse:关于GAN的一些思考
collapse发生原因
总结来说,一个概率分布是往往是复杂的,可能有多个mode(峰值),GAN的generator在训练的过程中容易往其中一个mode逼近,导致生成的结果单一,即改变了generator输入的latent z,但是生成的结果G(z)没有发生变化。
解决办法: