生成对抗网络论文《Contextual RNN-GANs for Abstract Reasoning Diagram Generation》阅读笔记

生成对抗网络(GAN)自2014年横空出世后,成为深度学习领域的又一有力科研利器,各大顶会都有他出没的身影。由于之前尝试使用过RNN做文本正则化(Text Normalization),因此在看生成对抗网络时自然地偏向于关注能处理序列的RNN与GAN的结合。闲话不说,论文标题《Contextual RNN-GANs for Abstract Reasoning Diagram Generation》一目了然,表明作者是使用能处理上下文信息的RNN与GAN的结合模型,来完成抽象图的推理生成任务。

要点

  • 文章概况
  • 创新点
  • 核心模型:Context-RNN-GAN
    • Loss顺便一提
    • 训练过程
    • Loss改进
  • 其他模型
  • 特征表示(Feature Representations)
  • 实验与结果
    • 数据集
    • DAR
    • Next Frame Generation
  • 个人总结

文章概况

论文前面部分大致探讨了抽象推理这一任务的内容,具体地则是提出了一个称为 图解抽象推理(DAR) 的任务,即给定一组有序的问题序列图然后要求产生一个新的图,产生的新图要符合之前问题序列的规律。而论文要实现的就是通过带有RNN的GAN模型生成符合规律的问题解图。
生成对抗网络论文《Contextual RNN-GANs for Abstract Reasoning Diagram Generation》阅读笔记_第1张图片
如上图,任务需要通过前面5张图找到规律,生成符合规律的第6张图。图示中的Ground-truth即为正确的答案,而Generation则是GAN模型得到的问题解图。
使用这样的RNN-GAN模型还可以用于生成真实世界的图像和视频(这是计算机视觉和深度学习的最新研究方向),因此论文又说明了可以完成 下一帧图像的预测 (next frame generation) 这一特定任务。
在这里插入图片描述
图示中的前5帧图像为时刻1~5连续移动变化的图像,倒数第二张图像是第6时刻运动的真实图像,而最后一张图是GAN产生的运动图像。
论文完成了以上两个任务的实验,并表明使用RNN-GAN网络模型可以产生较好的结果。

创新点

论文Introduction章节中总结了论文的创新点如下:

  • A new abstract reasoning and image generation dataset.
  • A novel temporally contextual, RNN-based adversarial generation model.
  • A novel feature representation of temporally adjacent images using a Siamese Network.
  • Strong performances on two datasets (DAR and MNIST videos), competitive with 10th-grade humans and comparable state-of-the-art.

在我看来,最重要的就是率先提出了使用RNN和GAN结合的模型,可以学习到输入序列的信息、并生成与输入序列相同形式的数据(在论文中就是生成图像)。其中RNN可以学习序列潜在信息,GAN用来生成目标数据。

核心模型:Context-RNN-GAN

Loss顺便一提

在提出使用GAN模型时,文章提了一下使用GAN的好处:因为使用GAN模型时计算Loss是通过生成模型的误差Loss(G)与判别模型的误差Loss(D)的求和所得,这其中涉及到GAN设计的初衷——是一个最小最大化计算Loss的优势(是一个G和D的博弈平衡问题,G总想欺骗D,D总想分辨G产生的数据和真实数据),而如果不使用GAN,使用L1范式Loss在实验中可能会丢失actual diagram信息,使用L2范式Loss在实验中产生的图像会有叠加造成的杂乱无章效果。
关于核心模型Context-RNN-GAN,简单来说就是生成器D和判别器G都使用RNN来处理。

训练过程

生成对抗网络论文《Contextual RNN-GANs for Abstract Reasoning Diagram Generation》阅读笔记_第2张图片
上图所示,即是文章所提出的Context-RNN-GAN模型在处理序列数据时的过程。
对于输入序列X=(x1,……xt-1)的训练,每次训练为不断重复先训练D,再训练G的过程:

  • 训练判别器D过程中,如果训练生成器G产生的伪数据,则每次输入给判别器D的是前面所有时刻的数据x1 ~ xt-1,以及G在t时刻生成的伪数据yt,此时训练的标签输出为0(代表假);如果训练真实数据,则每次输入给判别器D的是前面所有时刻数据x1 ~ xt-1,以及t时刻真实数据xt,此时训练的标签输出为1(代表真)。
  • 训练生成器G过程中,在t时刻G输入数据xt,产生输出yt,将G网络与D网络串接(即G的输出作为D的输入,与图示一致),此时输出G+D的网络的输出使用标签1(代表真,目的是使G生成的数据欺骗D)来计算Loss并反向传播训练(保持D的参数不变)。

Loss改进

训练判别器D时,Loss的计算为:
生成对抗网络论文《Contextual RNN-GANs for Abstract Reasoning Diagram Generation》阅读笔记_第3张图片
训练生成器G时,Loss的计算为:
在这里插入图片描述
在计算G的Loss时实验发现最小化上述Loss可能导致训练不稳定(作者说是因为生成器生成能够欺骗鉴别器的图像与它想要正确建模的xt + 1完全不同导致,此处不甚理解),因此提出了加入一个正则项Lp来调和,关于Lp计算如下:
生成对抗网络论文《Contextual RNN-GANs for Abstract Reasoning Diagram Generation》阅读笔记_第4张图片
从而得到生成器G真正的Loss为:
在这里插入图片描述

其他模型

为了进行实验效果比较,论文又列出了以下一些模型:

  • RNN-GAN:简化版的Contextual RNN-GANs,将判别器由RNN替换成全连接网络(Fully Connected Network),由此原本判别器函数由D(x1,…,xt, yt)变为D(yt)。
  • RNN: 仅使用RNN作为模型,使用的是GRU结构。Loss计算使用L1和L2范式Loss分别测试。
  • Feed-forward Baseline:单纯的全连接多层感知器(即全连接网络),训练时使用前4张图片作为输入,预测第5张图片;测试时使用第2到第5张图片来预测第6张图片。

特征表示(Feature Representations)

为了提升实验效果,论文中提出了几种提取图片数据特征作为网络输入的方法,即所谓的 “image embedding”:

  • Raw Pixels:不加改动,只是将所有图片调整为128x128的原始像素,作为输入
  • Histogram of Oriented Gradients:定向梯度直方图
  • Autoencoder:自动编码器
  • Pretrained CNN:预训练的CNN
  • Fine-tuned AlexNet model :精心调整的AlexNet模型
  • Shallow CNN:浅层CNN
  • Siamese CNN:论文发现实验效果最好的一种CNN图像嵌入

实验与结果

数据集

论文作者们从IQ测试书籍和在线资源处收集了大约1500个训练问题(其中包括约15000张图片,即每个问题对应10张图片),然后使用transform方式(rotation和mirror reflections)扩展数据集为原始数据的8倍(即约12000个问题,120000张图片)。测试过程中使用100个问题,5张图片作为问题的输入,5张图片作为问题答案的候选,这样做方便实验效果量化评估。

DAR

模型实验参数参照论文,在此不赘述。实验结果如图:
生成对抗网络论文《Contextual RNN-GANs for Abstract Reasoning Diagram Generation》阅读笔记_第5张图片
与之对比的是人类的测试效果,论文作者们请到了两组人来做实验,一组是大学本科生,另一组是10年级的中学生:
生成对抗网络论文《Contextual RNN-GANs for Abstract Reasoning Diagram Generation》阅读笔记_第6张图片
可以发现Contextual RNN-GAN的最高准确率35.4% 与10年级中学生的平均准确率36.67% 非常接近,但是与大学本科生平均准确率44.17% 仍有一定差距。
然后给出了部分生成结果较好的实验测试图:
在这里插入图片描述

Next Frame Generation

关于下一帧预测的实验,给出了部分实验效果图:
生成对抗网络论文《Contextual RNN-GANs for Abstract Reasoning Diagram Generation》阅读笔记_第7张图片

个人总结

这篇论文最大的的贡献就是提出了使用RNN+GAN的方式来处理序列化输入的数据来产生有效的(或者说是接近逼真答案)目标输出。当处理的数据不是单个图片、文本、音频时,可以很好的借鉴RNN的记忆过去优势来解决问题,这是RNN的长处;使用GAN这种对抗——生成的训练方式,可以更好地产生接近于真实数据分布的目标数据。所以,RNN+GAN其实是在特定场景下(处理序列时)两种模型的长处互补,达到精益求精、更上一层楼的效果!

你可能感兴趣的:(深度学习)