Improving Image Captioning with Conditional Generative Adversarial Nets理解

Chen C , Mu S , Xiao W , et al. Improving Image Captioning with Conditional Generative Adversarial Nets[J]. 2018.

一、前言

图像标注(image captioning)是一门综合计算机视觉和自然语言处理的深度学习研究,相比于图像分类(image classification)、目标检测(object detection)和语义分割(semantic segmentation)等任务,其更复杂,也更具挑战性。

最早,人们基于encoder-decoder提出CNN-RNN结构进行图像标注,其中CNN用作图像表示 I I I,RNN用作句子生成 G θ G_{\theta} Gθ,并使用极大似然估计进行训练 J G ( θ ) J_G(\theta) JG(θ),但存在误差累积问题。

J G ( θ ) = 1 N ∑ j = 1 n log ⁡ G θ ( x j ∣ I j ) = 1 N ∑ j = 1 N ∑ t = 1 T j log ⁡ G θ ( x t j ∣ x 1 : t − 1 j , I j ) J_G(\theta) = \frac{1}{N}\sum_{j=1}^{n}\log G_{\theta}(x^j|I^j) = \frac{1}{N}\sum_{j=1}^N\sum_{t=1}^{T_j}\log G_{\theta}(x^j_t|x^j_{1:t-1}, I^j) JG(θ)=N1j=1nlogGθ(xjIj)=N1j=1Nt=1TjlogGθ(xtjx1:t1j,Ij)

该问题的产生与图像标注任务的输出有关。图像标注任务是给定一张图片,生成一个句子,句子由 T T T个单词组成,因此句子的产生过程是基于已生成的单词去生成下一个单词。训练时,每个单词使用的都是ground-truth;而测试时,下一个单词是基于模型已生成的单词得到的。因为模型生成的单词不一定与ground-truth相同,一旦不相同,那么以后的单词都是基于错误的单词而得到(而模型的参数是通过ground-truth得到的)。单词数目越多,累积误差越大。

为解决误差累积问题,人们将强化学习整合到训练阶段,直接优化语言指标(language metrics),如BLEU、METEOR、ROUGE、CIDEr和SPICE等。这些语言指标是在一个完整的句子生成后,计算生成句子与对应ground-truth的差异(区别于CNN-RNN中的计算单词间差异)。即强化学习最大化

L G ( θ ) = E x s ∼ G θ [ r ( x s ) ] , x s = ( x 1 s , x 2 s , ⋯   , x T s ) L_G(\theta) = E_{x^s \sim G_{\theta}}[r(x^s)], x^s = (x^s_1, x^s_2, \cdots, x^s_T) LG(θ)=ExsGθ[r(xs)],xs=(x1s,x2s,,xTs)

近些年来,生成对抗网络在计算机视觉领域有着不俗的表现,不少工作将其与强化学习结合应用于图像标注任务,虽然提高了生成句子的通顺性和多样性,但却造成精度的丧失。并且,有些工作为了展现更好的结果,只在某一个语言指标上进行了度量。

二、论文模型

2.1 总架构

Improving Image Captioning with Conditional Generative Adversarial Nets理解_第1张图片
总架构如上图所示。本文将传统的CNN-RNN作为Generator,加入新的Discriminator和language evauator,组成完整的架构。不同于之前强化学习的reward,本文的reward通过加权Discriminator和language evauator得到,平衡了生成句子的准确度和通顺性。

2.2 Discriminator设计

本文分别基于CNN和RNN设计出两种不同的Discriminator,如下图所示。这两个不同的Discriminator都是基于基本的深度学习方法得到。
Improving Image Captioning with Conditional Generative Adversarial Nets理解_第2张图片
值得注意的是,Discriminator生成了三种不同的训练数据,包括一个正样本数据集和两个负样本数据集,其中正样本数据集使用ground-truth句子和匹配图片,负样本数据集使用生成句子和匹配图片,另一个使用ground-truth和不匹配图片。

2.3 算法设计

完整的算法设计见下图。
Improving Image Captioning with Conditional Generative Adversarial Nets理解_第3张图片

三、实验

值得注意的地方:

  • 线下训练时只使用不到二十分之一的数据作为测试集,线上训练更是使用所有的数据作为训练集。
  • 使用控制变量法选择合适的超参数。
  • 使用集成学习集成4个CNN-based Discriminator和4个RNN-based Discriminator,并获得最好的性能。
  • 做了大量的对比试验。

你可能感兴趣的:(Improving Image Captioning with Conditional Generative Adversarial Nets理解)