用自我监督学习克服语言先验的视觉问答

       大多数VQA模型都有语言先验的问题,语言先验:归根结底就是说模型过度依赖question,或者完全凭借问题就可以给出答案,而不会结合我们给的图片。可能这在训练集上会有较好表现,但是模型在OOD的数据(训练集分布外的数据)上取得较差的泛化性能。我认为这有点类似与人类的惯性思维,就比如我们在考科目一的时候,通常都会刷几遍题库,到后来基本上只看关键的几个字就能知道选哪个,也在模拟题上取得了不错的分数(训练数据集)。但是当我们考试的时候一旦题意与原题有一点差别,我们就可能会做错(域外数据集)。综上所述,解决语言先验的大体思路就是怎样减少对question的依赖性,加强对image的依赖性。这篇文章就用自监督的方法去平衡数据偏差,使得VQA模型克服语言先验。 

解决的问题及贡献

●本文提出了一个QICE(问题-图像相关性评估)辅助任务,并加入到自监督学习框架中来帮助VQA模型克服语言先验问题。

●首先通过设置一个自动生成的问题-图像相关性标签c(c=0/1,表示无关/相关),将部分有偏差的数据自动转化为平衡数据,从而得到一个平衡的数据集(不相关和相关的实例个数相同)。然后基于此数据集来训练自监督的辅助任务,来判断问题-图像是否相关(二分类任务),并给出预测分数。

用自我监督学习克服语言先验的视觉问答_第1张图片

网络框架

用自我监督学习克服语言先验的视觉问答_第2张图片

●图(a)表示一个最基本的VQA模型流程,我们输入一个问题-图像(Q-I)对,VQA模型会给出所有答案的预测概率,再将预测答案与真实答案对比,计算出损失Lvqa。

●图(b)表示我们之前提到的平衡有偏差数据的处理方法,具体操作就是我们从原数据集取出若干Q-I对,不管它们是否相关,全部将其相关标签c设为1(即相关)。然后将各个Q-I对的图像image丢掉,从图像空间中再随机取一个image,与原问题组成一个新对,我们再将这些新对的标签c全部设为0。将上述我们操作过的所有Q-I对作为一个新的数据集,用来训练。以上操作可由计算机自动完成,无需人工注释。

●图(c)表示如何处理相关和不相关对。在训练过程的前12轮我们只处理相关对,之后才加入不相关对一起训练。Lvqa表示相关对的损失(VQA模型损失),Lqd表示不相关对的损失(问题依赖损失)。我们在训练过程中会想要最大化相关对的答案概率,最小化不相关对的答案概率。(具体原因在下一节详述)从而最小化损失Lvqa和Lqd,最后实现最小化模型整体损失Lself。

主要公式

(1)式表示模型预测第i个实例的答案概率。(2)式表示最小化交叉熵损失。【Ai表示正确答案位置】(3)式表示多标签交叉熵损失。δ( )表示sigmoid函数,ti表示第i个例子每个答案的软目标得分(soft target score,对于每个候选答案,根据它们与标签答案的相似性计算一个软目标得分),我们在训练的时候只需要选一个损失函数作为VQA模型部分的损失函数即可。

用自我监督学习克服语言先验的视觉问答_第3张图片

公式(5)是我们模型总的训练损失。通过最小化VQA损失 来优化自监督训练损失函数,模型中是使用Lself作为反向传播更新参数的损失函数。VQA模型的预测概率P(A|Q,I)可视为(I,Q)为相关对的置信度。概率越大,匹配度越高。

 用自我监督学习克服语言先验的视觉问答_第4张图片

 公式(6)为问题依赖损失。因为Lqd只对不相关的问题-图像对(Q, )有效(不相关对的ci=0),所以公式(6)中省略了ci。由公式可知,当我们的预测概率P越小的时候,损失Lqd也会越小,而概率越小代表我们预测的越不准确。又因为我们这是不相关对的损失函数,越不准确也就代表预测答案越不接近真实答案,即模型受语言先验影响的程度越小。Lqd也可以看做我们对于不相关对所评价的相关程度(越小越不相关),所以最小化不相关对的相关度可以显式地防止VQA模型过度受语言先验的影响。 

综上公式,模型的总损失Lself就由VQA模型损失Lvqa和问题依赖损失Lqd两部分组成。

  

实验结果

下图表示各种模型与本文基于UpDn模型准确率的对比,可见本文的方法不仅可以提高UpDn模型的整体性能(交叉熵损失+14.35%,多标签损失+16.06%),而且优于性能最好的SCR方法(交叉熵损失+3.13%,多标签损失+8.09%),而且在“Yes/No”问题类型上,获得极高的准确率(87.75%和86.53%),这表明确实有效地克服了语言先验(因为当某个答案的得分相较其他特别高的时候,就越容易产生语言先验,所以越简单的答案往往越能反映语言先验的程度)。

用自我监督学习克服语言先验的视觉问答_第5张图片

注:第一部分是非基于注释的模型性能,第二部分是基于注释的模型性能,第三部分是加上本文辅助任务模型的性能。可见,无论使用哪种VQA损失,本文的方法都优于上述所有方法(包括性能最好的方法)。†表示重新实现我们的基线。 为交叉熵VQA损失, 为多标签VQA损失。

对于问题类型‘is this’

训练集上的回答得分:

用自我监督学习克服语言先验的视觉问答_第6张图片

测试集上的回答得分:

用自我监督学习克服语言先验的视觉问答_第7张图片

预测值的回答得分:

你可能感兴趣的:(学习)