GAN的训练技巧:炼丹师养成计划 ——生成式对抗网络训练、调参和改进

目录

  • 一、模式崩溃: 生成器产生的结果模式较为单一
    • 1.1、改进训练方法
    • 1.2、改进目标函数
    • 1.3、改进网络架构
  • 二、训练缓慢:发生了梯度消失
  • 三、不收敛:训练不稳定,收敛的慢
  • 四、过拟合
  • 五、尽早发现失败
  • 六、一些训练技巧
  • 最后

生成对抗网络(GAN:Generative adversarial networks)是深度学习领域的一个重要生成模型,即两个网络(生成器和鉴别器)在同一时间训练并且在极小化极大算法(minimax)中进行竞争。这种对抗方式避免了一些传统生成模型在实际应用中的一些困难,巧妙地通过对抗学习来近似一些不可解的损失函数。
GAN的训练技巧:炼丹师养成计划 ——生成式对抗网络训练、调参和改进_第1张图片

之前我们介绍了GAN的原理:深入浅出 理解GAN中的数学原理,GAN最重要的就是找到D与G之间的纳什均衡,但是在实际中会发现GAN的训练不稳定,训练方法不佳很容易出现模式崩溃等问题,本篇将记录一些训练技巧,不一定适合你的模型,也可能有疏漏和错误,供学习参考,欢迎指正和补充。

一、模式崩溃: 生成器产生的结果模式较为单一

模式崩溃现象狭义上来说是生成器仅仅产生单个或有限的模式来欺骗鉴别器,仅仅只是为了得到最低的判别器损失D_loss,却忽视了数据集的分布,比如一个动物图像数据集,GAN在训练时候发现生成猫和狗的效果非常好,生成牛、羊、猴子等效果很差,整个G就只去生成猫狗,根本不去学习生成其他的动物图像,就会导致生成的图像单一。模式崩溃现象本质上还是GAN的训练优化问题,即使是最优秀的 GAN 研究人员也在与模式崩溃作斗争。

解决模式崩溃有很多方法,如下:

1.1、改进训练方法

  1. 小批量鉴别器(mini-batch discriminator):因为判别器每次只能独立处理一个样本,生成器在每个样本上获得的梯度信息缺乏“统一协调”,都指向了同一个方向。于是小批量让判别器不再独立考虑一个样本,而是同时考虑一个小批量的所有样本,具体实现可以看:小批量判别器如何解决模式崩溃问题。
  2. 经验重播:每隔一段时间向鉴别器显示旧的假样本,可以使模式间的跳来跳去最小化。这可以防止鉴别器变得太容易被利用,但仅限于生成器过去已经探索过的模式。
  3. 调整GAN的学习速度(学习率):通过改变这个特定的超参数来克服这个阻碍,使用较小的学习率,并从头开始训练,学习速度是最重要的超参数之一,即使不是最重要的超参数,即使是它微小变化也可能导致训练过程中的根本性变化。
  4. 特征匹配:特征匹配改变了生成器的cost function,以最小化真实图像和所生成图像的特征之间的统计差异,测量它们的特征向量均值之间的 L2 距离。
  5. 把多个属于同一类的样本进行打包,然后传递给判别网络 D。
  6. 预计反攻:生成器在更新时,不仅仅考虑当前生成器的状态,还会额外考虑K次更新后判别器的状态,综合两个信息做出最优解,即参数更新方式为采用梯度下降方式连续更新K次,提高生成器的“先见之明”,从而避免了短视行为。首先将参数更新方式改为采用梯度下降方式连续更新K次,如下:
    θ D 0 = θ D … ⋯ θ D K = θ D K − 1 + η ∂ f ( θ G , θ D K − 1 ) ∂ θ D K − 1 \begin{aligned} \theta_{D}^{0} &=\theta_{D} \\ & \ldots \cdots \\ \theta_{D}^{K} &=\theta_{D}^{K-1}+\eta \frac{\partial f\left(\theta_{G}, \theta_{D}^{K-1}\right)}{\partial \theta_{D}^{K-1}} \end{aligned} θD0θDK=θD=θDK1+ηθDK1f(θG,θDK1)
    生成器的优化目标改为: θ G = arg ⁡ min ⁡ θ G f ( θ G , θ D K ( θ G , θ D ) ) \theta_{G}=\arg \min _{\theta_{G}} f\left(\theta_{G}, \theta_{D}^{K}\left(\theta_{G}, \theta_{D}\right)\right) θG=argminθGf(θG,θDK(θG,θD)),梯度的变化改为: d f K ( θ G , θ D ) d θ G = ∂ f ( θ G , θ D K ( θ G , θ D ) ) ∂ θ G + ∂ f ( θ G , θ D K ( θ G , θ D ) ) ∂ θ D K ( θ G , θ D ) ∂ θ D K ( θ G , θ D ) ∂ θ G \frac{d f_{K}\left(\theta_{G}, \theta_{D}\right)}{d \theta_{G}}=\frac{\partial f\left(\theta_{G}, \theta_{D}^{K}\left(\theta_{G}, \theta_{D}\right)\right)}{\partial \theta_{G}}+\frac{\partial f\left(\theta_{G}, \theta_{D}^{K}\left(\theta_{G}, \theta_{D}\right)\right)}{\partial \theta_{D}^{K}\left(\theta_{G}, \theta_{D}\right)} \frac{\partial \theta_{D}^{K}\left(\theta_{G}, \theta_{D}\right)}{\partial \theta_{G}} dθGdfK(θG,θD)=θGf(θG,θDK(θG,θD))+θDK(θG,θD)f(θG,θDK(θG,θD))θGθDK(θG,θD)

1.2、改进目标函数

  1. 特征匹配:改变生成器的损失函数;
  2. 用Wassernstein距离代替JS散度;
  3. 在梯度上加入惩罚项:WGAN-GP、DRAGAN;
  4. 引入pixel级别loss,特别是在训练早期,如L1, L2等;
  5. 在损失函数上加上正则项,帮助GAN找到更多多样性;
  6. 使用均方损失( mean squared loss )替代对数损失( log loss )。

1.3、改进网络架构

  1. 使用多个生成器,简单地接受GAN只覆盖数据集中模式的一个子集,并为不同模式训练多个生成器,而不是对抗模式崩溃,一起去生成图像,这样就可以生成多样化的图像;
  2. 自注意力机制:全局信息(长距依赖)会用于生成更好的图像。

二、训练缓慢:发生了梯度消失

  1. 网络使用残差结构:自适应网络深度,同时避免梯度消失;
  2. softmax+CrossEntropy loss:通过损失函数来抵消激活函数求导后造成的梯度消失影响
  3. 使用Adam优化器;
  4. 不要把判别器训练得太好,以避免后期梯度消失导致无法训练生成器,判别器的任务是辅助学习数据集的本质概率分布和生成器定义的隐式概率分布之间的某种距离,生成器的任务是使该距离达到最小;
  5. 对于层数过深的模型,尽量避免使用全连接层。

三、不收敛:训练不稳定,收敛的慢

  1. 生成器或鉴别器损失突然增加或减少时,不要随意停止训练,损失函数往往是随机上升或下降的,这个现象并没有什么问题,遇到突然的不稳定时,多进行一些训练,关注生成图像的质量,视觉的理解通常比一些损失数字更有意义;
  2. 添加噪声:通过添加噪声有利于提高系统的整体多样性和稳定性,在真实数据和合成数据(例如由生成器生成的图像)中添加噪声;在数学领域中,这应该是有效的,因为它有助于为两个相互竞争的网络的数据分布提供一定的稳定性;
  3. 软标签或者带噪声的标签:如果真实图像的标签设置为1,我们将它更改为一个低一点的值,比如0.9。这个解决方案阻止判别器对其分类标签过于确信,或者换句话说,不依赖非常有限的一组特征来判断图像是真还是假。

四、过拟合

在GAN中,如果鉴别器依赖于一小组特征来检测真实图像,则生成器可以仅生成这些特征以仅利用鉴别器。优化可能变得过于贪婪并且不会产生长期效益;

  1. 使用正则化来避免过拟合,常用的有L1、L2两种算法,如果已经使用了,调整其参数大小;
  2. dropout:让某些神经元以一定的概率停止工作。从隐藏层神经元中随机选择一个子集临时删除掉,然后训练时没有被删除的那一部分参数更新,删除的神经元参数保持被删除前的结果,不断重复这一过程;
  3. 软标签或者带噪声的标签(同上三)。

五、尽早发现失败

  1. D的loss一直接近于0,直接宣告失败。鉴别器太强了,生成器已经无法再产生更好的假数据了,也可以认为梯度消失了,这种情况很常见因为识别真假样本通常比伪造真实样本容易;
  2. D的loss居高不下,生成的图像很模糊不清,极有可能已失败。判别网络能力不行,胡乱分辨真假,甚至把真的误认为假的,假的误认为真的,生成器无法从判别器D那里学习到东西;
  3. 观察图像发现生成出来的图像单一,发生了模式崩溃,生成网络凑巧在生成某类真样本上特别得心应手,或者,判别网络对某类样本的辨别能力相对较差,那么生成网络会扬长避短,尽量多生成这类样本;
  4. 在一定的epoch后观察图像发现生成出来的图像模糊,全是噪声,极有可能已失败,梯度更新已经开始无意义,再往下训练也不会有改善,所以不要把时间浪费在无谓,病态的梯度更新上;
  5. GAN中loss体现的是判别器的判别能力,整体变化应该是降升、降升,最后趋于稳定。降是因为判别器性能增强了,升是因为生成器生成能力变好了。

六、一些训练技巧

  1. 将图像像素值缩放在-1到1之间,tanh作为生成器的输出层;
  2. 使用Adam优化器通常比其他更好;
  3. 使用PixelShuffle和转置卷积进行上采样;
  4. 使用Batch Normalization,其能提高网络泛化能力,使用BN后还可以不用理会过拟合中的drop out和L2正则化参数选择;
  5. 在将图像输入鉴别器之前,将噪声添加到实际图像和生成的图像中;
  6. 噪声尽量使用正态分布而不是均匀分布;
  7. 梯度惩罚;
  8. 激活函数使用LeakyRelu
  9. Two Timescale Update Rule (TTUR):不同的学习率,低速更新规则用于生成网络 G ,判别网络 D使用 高速更新规则,将判别器的学习率选为0.0004,将生成器的学习率选为0.0001也许可以达到不错的效果
  10. 反转标签,故意在部分样本上颠倒黑白,这个被放过的小鬼也许能刺激GAN别一条道走到黑;
  11. 在一定情况下打乱数据集,不然会导致网络在学习过程中产生偏见;
  12. 优先级:调参>更换损失函数>调整网络结构
  13. 不要采用早停法,要相信奇迹,除非判别器损失迅速趋近于 0;
  14. 不要放弃,一些微小改动将决定你的GAN模型能否训练成功。

部分参考自:
https://arxiv.org/pdf/1606.03498.pdf
https://towardsdatascience.com/gan-ways-to-improve-gan-performance-acf37f9f59b
https://www.zhihu.com/people/xiaomizhou94/posts

最后

个人简介:人工智能领域研究生,目前主攻文本生成图像(text to image)方向

个人主页:中杯可乐多加冰

限时免费订阅:文本生成图像T2I专栏

支持我:点赞+收藏⭐️+留言

如果这篇文章帮助到你很多,希望能点击下方打赏我一杯可乐!多加冰哦

你可能感兴趣的:(文本生成图像,text-to-image,笔记,深度学习,人工智能,计算机视觉)