改进GAN除了使用更复杂的网络结构和损失函数外,还有其他简单易行的方法吗?Uber的这篇文章或许可以给你答案,将GAN与贝叶斯方法相结合,在已经训练好的GAN上增加后处理步骤即可。本文对Uber的这篇最新工作进行了简要介绍,如果对内容感兴趣还可以点击文末的原文链接阅读论文,同时文末还提供了该方法的开源代码,你可以轻松用它来提升自己的GAN模型。
更多干货内容请关注微信公众号“AI前线”(ID:ai-front)
生成对抗网络(GAN)不仅在真实感图像生成和图像恢复方面取得了令人惊叹的效果,并且由GAN生成的一幅艺术作品也售出了40万美元的价格。
在Uber,GAN有大量具有潜力的应用,包括增强机器学习模型与对抗性攻击的对抗能力,学习交通模拟器,乘车请求或随时间变化的需求模式,以及为Uber Eats生成个性化的订单建议。
GAN由两个互相对抗的部分组成,一部分是生成器,一部分是判别器。生成器学习真实数据的分布,判别器负责需要学习如何区别真实样本和生成样本(即假样本)。大多数研究都致力于改进GAN的结构和训练过程来提高其性能,例如使用更大的网络结构或使用不同的损失函数。
NeurIPS2018的贝叶斯深度学习研讨会上,Uber的一篇论文中提供了一种新的思路:调整判别器用于在完成训练后从生成器中选择更好的样本。该工作提供了一种互补的抽样方法,Google和U.C. Berkeley在判别器舍选抽样(Discriminator Rejection Sampling,DRS)的研究与此方法也具有相同的思路。
Uber这篇工作以及DRS方法的核心思想可归纳为,如何使用已经训练好的判别器的信息来从生成器中选择样本,以保证这些被选择的样本尽可能符合真实数据的分布。通常,在训练完成后判别器就没有什么用了,因为在训练过程中会将判别器学到的知识编码到生成器中。然而,生成器往往是不完美的,判别器同时也会含有一些有用的信息,所以上述使用判别器信息来提升已经训练好的GAN的方法是值得一试的。Uber的研究团队使用了Metropolis-Hastings算法对分布进行抽样,并将采用这种方法得到的模型称为Metropolis-Hastings GAN,即MH-GAN。
GAN的训练过程通常被理解为两种条件之间的博弈,生成器需要尽可能让判别器产生误判的概率最大化,而判别器则需要尽可能的对真1z实数据和生成数据进行良好的区分。图1展示了这个过程,生成器使得函数值向极小值方向移动(橙色线条),而判别器则向极大值方向移动(紫色线条)。训练结束后,向生成器输入不同的随机噪声可以得到很方便得到生成样本。如果可以训练一个完美的生成器,那么生成器最终的概率密度函数pG应与真实数据的概率密度函数相同。然而,许多现有的GAN无法很好地收敛到真实数据的分布,因此从这种不完美的生成器中抽样会产生看起来不像原始训练数据的样本。
这种pG的不完美让我们想到另一种分布情况:判别器对生成器隐含的概率密度。这种分布被称为pD,并且它往往都很接近真实的数据分布pG。这是因为训练判别器是一种比训练生成器更简单的任务,因此判别器很有可能包含可以用于校正生成器的信息。如果我们有一个完美的判别器D和一个不完美的生成器G,使用pD而不是pG作为生成的概率密度函数等价于使用一个新的生成器G’,并且这个G’是可以完美地模拟真实数据分布的,如图一所示:
即使pD的分布可能与数据更匹配,但若想利用其得到样本数据并不像直接使用生成器那样直接。幸运的是,我们可以使用抽样算法从分布中产生样本,一种是舍选抽样法(Rejection Sampling,也被称为Acceptance-Rejection Sampling),一种是马尔科夫链蒙特卡洛法(Markov Chain Monte Carlo,MCMC)。这两种方法都可以作为一种后处理方法来提高生成器的输出;之前的判别器舍选抽样法(Discrimitor Rejection Sampling,DRS)借鉴了舍选抽样法的思路,而MH-GAN则采用了Metropolis-Hastings MCMC方法。
很多实际问题中,真实分布p(x)是很难直接抽样的的,因此,我们需要求助其他的手段来抽样。既然 p(x) 太复杂在程序中没法直接抽样,那么我们可以设定一个程序可抽样的分布 q(x) 比如高斯分布,然后按照一定的方法拒绝某些样本,达到接近 p(x) 分布的目的,其中q(x)叫做候选分布(Proposal Distribution)。
具体操作如下,设定一个方便抽样的函数 q(x),以及一个常量 k,使得 p(x) 总在 kq(x) 的下方。(参考上图)
x 轴方向:从 q(x) 分布抽样得到 a。
y 轴方向:从均匀分布(0, kq(a)) 中抽样得到 u。
如果刚好落到灰色区域即u \u0026gt; p(a),则拒绝,否则接受这次抽样。
重复以上过程便可得到p(x)的近似分布。该方法两大挑战分别是:
k的值通常是人为经验设置的,无法确定一个准确的值。若k值设置的过大可能导致拒绝率很高,增加无用计算;若k值过小则有可能找不到正确的p(x)分布。
合适的q(x)分布通常很难找到。
在GAN中,pD即为目标分布对应上述p(x),pG为现有的分布对应上述q(x)。所以在GAN中使用该方法的难点主要来源于k值的确定,或因k值太小而无法正确抽样,或因k值过大而在高维空间中产生大量的计算。为了解决样本浪费问题,DRS启发式地增加了一个γ调整判别器分数,使得判别器D即使是完美的情况下,从分布中产生的样本仍能够与真实样本存在差异。
Uber的这篇工作使用了Metropolis-Hastings(MH)方法,这是马尔科夫链蒙特卡洛法一类方法中的一种。这一类方法被最初是作为舍选抽样法在高维空间中的代替而发明的,它们通过从候选分布中多点抽样得到一个尽可能复杂的概率分布,然后再对这个概率分布进行抽样。MH包含两步,第一步是从候选分布中(例如,生成器)选择K个样本,然后从K中依次选择一个样本,决定是接受当前样本还是根据接受规则保留先前选择的样本,如图3所示:
其中,D是判别器分数,由以下公式得到
K是一个超参数,对其调整可以在速度和置信度之间做出权衡。对于一个完美的判别器K趋近于无穷,即D的分布完美的接近了真实数据分布。
噪声样本被独立地输入生成器,经过K次生成得到可以符合MH选择器条件的状态链。独立的链被用于从MH-GAN的生成器G’中获取多样本。
对于MH算法,由于初始点的不确定性,大部分情况下算法会经过一段长时的预烧期才能开始有效的优化过程,即在开始接受第一个数据点之前会拒绝很大一部分数量的数据点。为了避免这种情况,本文对如何初始化状态链的方法进行了详细的介绍。在清理和初始化每一条状态链时,可以使用真实数据的采样结果对状态链进行优化。在遍历了整个状态链之后,如果没有一个数据被接受,MH-GAN会从生成样本中重新开始抽样,从而确保真实数据中的样本不被输出。值得注意的是,MH-GAN不需要真实的样本进行初始化,只需要它所对应的判别器分数即可。
实际上,得到完美的D是不可能的,但是通过校准步骤可以达到相对完美的程度。另外,完美判别器的假设也不一定就真如它看起来那么好用。因为判别器仅对生成器和最初的真实数据进行评价,它只需要对来自生成器和真实数据分布的达到精确判别就可以。在一般的GAN训练中,一般不需要严格的要求判别器D的值达到一个确定的边界。但是MH算法需要从概率密度比方面对这个值进行良好的校准,从而得到正确的接受比。MH-GAN使用10%的训练数据作为随机测试集,使用保序回归的方法对判别器D进行调整。
Uber在论文中使用了一些小例子对MH-GAN和DRS方法进行了比较,其中真实数据来源于四个单变量的高斯模型的混合结果。通过pG的概率密度图可以看出普通的GAN存在的通病,它们的生成结果都缺失了一种模式(如图4所示)。但是,不使用γ校正DRS和MH-GAN则能良好的还原混合模型,而使用γ进行调整的DRS不能还原原始分布。然而,与使用γ进行调整的DRS方法相比,不使用γ的DRS方法在第一次接受之前抽样的数量增加了一个数量级。
这部分内容主要展示了MH-GAN在真实数据上的效果,分别测试了选取使用了梯度惩罚的DCGAN和WGAN作为基础GAN的结果。在图6的表格中展示了校准后的MH-GAN的感知分数(Inception Socre)。
感知分数会完全忽略真实数据而只是用生成的图像进行评价,它需要将生成图像传入在ImageNet上预训练好的感知分类器中,感知分数会对输入图像属于某个详细类的置信度和预测类别的多样性进行测量。尽管感知分数存在缺陷,但它仍被广泛用于与其他工作进行比较。
基本上校准后的MH-GAN比其他方法都可以取得更好的效果,但是在整个训练过程中这种优势并不是一直存在的。对于这种情况的一个解释是,对于某一轮的迭代,判别器的分数与理想的判别器分数存在巨大差异,从而导致了接受概率缺乏准确性。
MH-GAN是一种提升GAN生成器的简单方法,该方法使用Metropolis-Hastings算法作为一个后处理步骤。在模拟数据和真实数据上MH-GAN都表现除了超越基础GAN的效果,与最近提出的DRS方法相比MH-GAN也更具有优势。目前该方法仅在较小的数据库和网络上进行了验证,下一步Uber计划将该方法用于更大的数据库和更先进的网络。将MH-GAN方法扩展到大规模数据库和GAN的途径是非常简单粗暴的,因为仅需要额外提供判别器分数和生成器产生的样本就可以!
此外,使用MCMC算法提升GAN的思想也可以扩展到其他更高效的算法上,例如汉密尔顿蒙特卡洛方法。如果想获取关于MH-GAN的更多细节和图表可以阅读论文:Metropolis-Hastings Generative Adversarial Network,如果想复现该工作,Uber提供了该方法基于Pytorch的开源代码。
阅读英文原文:https://eng.uber.com/mh-gan/