什么是OOD(out-of-distribution)、ID(in-distribution) 数据?
举个例子,现在有一个已经训练好的图像分类模型,在训练集上可以很好地分类猫和狗的图片(也就是说训练的答案类型中只有猫和狗)。现在我们输入一个蛇的图片,这个模型是肯定输出不了“蛇”这个类别的,因为他都没见过。但是,如果输入一个猫或狗的图片,这个分类器肯定是可以进行精确分类的。那么这个蛇的样本就是OOD数据,而猫或狗的样本就是ID数据。在VQA模型中则是通过答案分布来区分OOD和ID,其对应的就是有偏样本和无偏样本。
除了常见的语言先验(language priors)问题,作者还发现大多数缓解语言先验的模型虽然在VQA- cp v2上获得了收益,但是在VQA v2上的ID性能反而降低了。这表明这些方法的成功仅仅来自于对模型的其他方向的偏见,而不是赋予它们推理能力和对语言先验的鲁棒性。所以产生了一个新问题——权衡问题。归根结底,语言先验是因为过度依赖有偏见样本中的偏见信息,权衡问题则是因为有偏见样本的重要性被削弱。所以作者想要构建一个模型,使其能够精确地利用有偏差的样本,并获取给定任务的内在信息,这样上述的两个问题就可以同时得到缓解。
下图为作者的LMH+MMBS方法、普通方法UpDn、去偏方法LMH三者的定性比较。结果表示LMH损害了ID性能,而MMBS在提高OOD性能的同时保持ID性能。
贡献:
●作者提出了一种自监督对比学习方法(MMBS),通过充分利用有偏差的样本,有效地缓解了语言先验问题和ID-OOD性能权衡问题。
●作者还提出了一种区分有偏样本和无偏样本的算法,使得在对比学习中可以区别并处理它们。
本文的方法适用于各种主干VQA模型。此模块计算最小化多标签软损失Lvqa。tgti为每个标签(G-T多标签)对应的目标分数,δ表示sigmoid函数。
目的:构造排除有偏信息的正样本。
构建正问题两种方法:混排(Shuffling)、删除(Removal)
本文提出的四种应用策略:S(仅Shuffling)、R(仅Removal)、B(两个都用)、SR:non-yesno(如“Num”和“Other”)问题使用混排,yesno(如“Y/N”)问题使用删除。
通过采用上述任一策略,输入一个样本 ,可以得到正样本 。负样本 ,(b≠i)为同批次的其他样本。B为训练批次大小。
本文提出了一种新的算法,包括以下的三个步骤:
(1)计算答案Aj在问题类别Ck中出现的频率(以每类问题的总得分作为频率)。第ith个样本的问题类别、G-T答案和软目标得分各表示为Ci、Ai和tgti,MCk是类别Ck的所有样本的数目。若有多个标签的答案Ai,则分别计算每个答案的得分。
的值越低,说明Aj和Ck之间的伪相关性越弱,因此认为该样本是无偏的。并且引入了一个超参数β∈[0,1]来控制无偏样本的比例。
(2)确定无偏答案比例。当一个问题类型答案分布的熵较低时,表示大多数的答案类型只与少数的样本相关联,因此就要使无偏答案的比例更高。否则,它应该更低。如下图表示:“Does the”的yesno问题和“How many”的non-yesno问题的答案分布。前者熵值低,后者熵值高。
于是,作者提出了一个基于熵的校正因子WCk来动态调整每个类别Ck的β,最后得到无偏答案比例PCk = WCk * β。
其中E表示(所有答案类型熵值的总和),SUM表示FreqCk的总和,mean()表示对所有答案类型总熵值E求平均,香农熵Entropy(x)= -x*log2x,其函数图像如下:
当熵ECk较低时,WCk更接近于1,否则WCk更接近于0。如下图所示:
(3)选择无偏样本。对于每个问题类型Ck,都得到一个无偏见的答案列表(通过PCk得到无偏答案个数a,再从中取分数最低的a个)。然后对于一个输入样本,就可以通过判断其G-T(最高分)答案是否属于这个列表,是则为无偏样本,否则为有偏样本。对于无偏样本,作者提出使用原始样本作为它的正样本。而对于有偏样本,则采用上面提到的策略来构建它的正样本。
给模型输入一个样本(Ii, Qi),同一批次中又会有两个输入:正样本和负样本,其中b≠i。经过VQA模型后,得到原始样本F(Vi, Ti)、正样本F(Vi, Ti +)和负样本F(Vb, Tb)的多模态融合表示,分别记为锚点a、正p、负nb。然后作者使用了余弦相似度cos(·)作为评分函数,得到对比损失Lcl的公式为:
训练过程中通过最小化它,模型就可以专注于来自正问题的无偏信息。MMBS的总损失L为:
L =Lvqa + α * Lcl,其中α为Lcl的权重。
上图为本文方法基于基础模型(BAN、UpDn、LXMERT)和去偏方法(LMH 、SAR )在两个测试集上的性能分析。“Gap”表示原模型相比加MMBS方法的性能提升额度。可见,这些模型融合了本文的MMBS方法后,不仅在OOD测试集VQA-CP v2上模型性能有提升,在ID测试集VQA v2上的性能也有轻微提升(但在LMH模型上提升较大)。