RL-GAN Net -- 首个将强化学习与GAN结合的网络

RL-GAN Net

  • 引言
  • 背景知识
    • 强化学习
    • 生成对抗网络
  • 核心思想
    • 基本框架
    • 损失函数
    • 网络结构
  • 实验结果对比
  • 总结

引言

作者首次将强化学习和生成对抗网络结合起来,用于点云数据生成。通过控制GAN将有噪声的部分点云数据转换为高保真度的补全形状。由于GAN不稳定且难以训练,作者通过:(1)在隐空间表示上训练GAN来解决该问题,该表示的维数与原始点云输入相比减小了;(2)使用RL agent来找到更优的输入给GAN以生成最适合当前不完整点云输入的形状的潜在空间表示。这是首次尝试训练一个RL agent来控制GAN,它有效地学习了从GAN的输入噪声到点云潜在空间的高度非线性映射。原文链接:RL-GAN-Net。代码链接:RL-GAN Net-Github
让我们来看看效果:
RL-GAN Net -- 首个将强化学习与GAN结合的网络_第1张图片
如图所示,最左边一列是groundtruth图像,Pin表示输入的缺失点云图像,AE表示自编码器产生的补全输出图像,RL-GAN-Net表示作者的方法得到的结果,最后一列是作者的方法与groundtruth重合得到的结果。接下来看一下具体是如何实现的。

背景知识

本文将强化学习(Reinforcement Learning)和生成对抗网络(Generative Adversarial Nets)结合起来,提出了一种新的RL-GAN-Net方法,巧妙的利用RL agent获得生成器的优质输入,并利用判别器获得完整的点云图像输出。是一项具有开创性意义的工作。我们先看一下RL和GAN的基础知识。

强化学习

RL-GAN Net -- 首个将强化学习与GAN结合的网络_第2张图片
强化学习是机器学习中的第三个分支(前两个分支是监督学习和无监督学习)。如图所示,强化学习可以分为两大类,有模型学习(model based)和无模型学习(model free),其中常用的是无模型学习,即马尔科夫决策过程中的策略和奖励函数都未知的情况。其次,无模型方法又分为基于策略优化(policy based)和基于值优化(value based)的方法,本文用到的强化学习方法是策略优化方法和值优化方法的结合DDPG,即深度确定性策略梯度(Deep Deterministic Policy Gradient)。DDPG方法是从最基础的PG(policy gradient)演变而来的,经历了PG→DPG→DDPG。

生成对抗网络

生成对抗网络(GAN)是Goodfellow在2014年提出一种生成模型,目的是生成与原始数据尽量相近的生成数据。
GAN网络主要由两个网络构成,生成网络G和判别网络D,生成模型G的思想是将一个噪声包装成一个逼真的样本,判别模型D则需要判断送入的样本是真实的还是假的样本,即共同进步的过程,辨别模型D对样本的判别能力不断上升,生成模型G的生成能力也不断上升。生成器的期望是将所生成的数据送入判别器后,判别器将其判别为真实数据;判别器的期望是将所有的生成数据和原始数据区分开。二者在博弈中共同提升性能。

核心思想

基本框架

RL-GAN Net -- 首个将强化学习与GAN结合的网络_第3张图片
该图是RL-GAN的前向传播示意图。图中E表示自编码器(AE)的编码器(encoder),E-1表示AE的解码器(为了和GAN的判别器作区分,这里的符号采用的是E-1),G表示GAN的生成器,D表示GAN的判别器。灰色的部分表示全局特征向量GFV(global feature vector),它的意义是输入的点云图像在潜变量空间中的低维表示,读者可以把它理解成一种降维方式。
首先,输入图像Pin被编码为潜变量表示GFV,分布送入RL agent和AE的解码器E-1,被送入AE的解码器的GFV经过解码得到输出结果PAE。而被送入到RL agent的GFV经过强化学习方法训练得到潜变量z,被送入到生成器,得到新的GFV(clean GFV),在经过解码器得到PGAN。后面的部分将在下文中提到,咱们先看损失函数。

损失函数

AE的损失函数很简单,是一个拟合点云图像输出的chamfer distance:
在这里插入图片描述
其中,P1和P2分别是输入和输出的点云图像空间。可以理解为先遍历P1空间中的所有点并分别计算与其最接近的点的距离,再遍历P2空间中的所有点并分别计算与其最接近的点的距离。损失函数的目的是是该距离之和最小化。
再来看GAN的损失函数,分成了三部分:
在这里插入图片描述
第一部分是输入图像Pin和GAN生成的点云图像的chamfer distance。
在这里插入图片描述
第二部分是生成器得到的clean GFV和AE的编码器得到的noisy GFV的二范数的平方,其中z是RL agent生成的潜变量。该损失函数存在的意义是保证GAN生成的点云图像在语义上尽量正确。换句话说,虽然GAN在细节上表现比AE好,但有时会输出错误语义的点云图像(实验部分有提到),加入这一损失能保证输出图像的语义正确性。
在这里插入图片描述
第三部分就是一个判别器的判别损失,最小化该损失,即最大化D的期望。

网络结构

RL-GAN Net -- 首个将强化学习与GAN结合的网络_第4张图片
这是RL-GAN的整体结构,除了前图中提到的部分,这里还给出了训练RL agent的部分。需要注意的是,网络中的AE和GAN是预训练好的,在训练RL-GAN时是直接使用的。Reward是RL中的奖励函数,agent希望在某个state采取一个最好的action获得更大的reward,奖励函数通过对损失函数取相反数获得,很明显,最大化奖励函数就等价于最小化损失函数。如图所示,RL agent在得到奖励函数和状态序列后,通过DDPG策略不断优化policy π,最终,训练好的π会在每个state得到一个确定的action输出。在该任务中,缺失点云图像是通过在原始点云图像中找到一个种子(seed),并以该seed为球心,以一定的半径确定一个球形,在原始图像中挖去这个球形得到的。RL的优化过程就是agent找到这个seed并确定半径的过程。
图中GFV reward来自clean GFV和noisy GFV,Discriminator reward来自判别器,Chamfer reward通过比较输入图像和输出图像获得。
RL-GAN Net -- 首个将强化学习与GAN结合的网络_第5张图片

除了上述RL-GAN-Net,论文还提出了一种混合的方法,见上图。AE和GAN获得输出图像后,又分别进入了AE的编码器E,将输出图像转化为GFV,同时送入判别器D中,由D决定哪种方法得到的输出更好,并输出较好的图像。文中提到,当点云缺失为20%时,AE产生的结果并不比RL-GAN的结果差,甚至更优。

实验结果对比

RL-GAN Net -- 首个将强化学习与GAN结合的网络_第6张图片
实际上,在衡量指标这一方面,作者选取的方法并不够好,也许是因为该领域研究者比较少,所以并没有很合适的衡量指标。从图中a可以看到,当缺失数据达到70%的时候,AE和RL-GAN的chamfer distance没有很明显的区分,而我们从图像上直观感受到AE的结果明显要差一些。为了凸显自己方法的优势,作者又做了分类实验,准确率明显高于AE(图b)。c图展示的是随着缺失比例增加,三个损失函数数值的变化,其中变化最明显的是GFV loss。

总结

总的来说这是一篇不错的文章,将GAN和RL两大最热门的方法结合起来是一个很具创新性的思想,对点云的补全效果也非常不错。不足在于没有找到更好的度量指标,将RL-GAN的优势以定量的形式展示出来。
目前这篇paper的代码已经公开了。在启发性方面,既然RL可以和GAN结合做点云图像生成,那一定也可以做其他工作,也许以后RL和GAN结合会成为一种趋势。当然具体能实现什么,这就要看大家的脑洞了。

你可能感兴趣的:(深度学习,人工智能,生成对抗网络,强化学习,机器学习)