Proactive Pseudo-Intervention——因果推断与对比学习结合,用于寻找因果特征

Proactive Pseudo-Intervention: Causally Informed Contrastive Learning For Interpretable Vision Models是2021年发表在CVPR上的一篇关于因果推断与对比学习相结合的文章。

本文出发点

当前深度学习模型大多是基于统计模型的数据驱动方式来学习,这种黑盒子的方式虽然可以直接通过数据来学习其隐含的相关性,但是也存在着许多的问题,
image from paper《*Towards Causal Representation Learning*》
例如:模型的可解释性差,以及当数据量不足或者数据类别分布不均匀时,模型出现的过拟合现象。为了避免这两个问题,作者提出用因果的角度来对模型进行训练,通过使得模型更多的去关注图像中与label有关的因果特征从而提升模型对背景信息等非因果信息的鲁棒性,同时可以通过事后可视化的方式标记出saliency mapping,从而使得我们对模型有更好的解释性(e.g. 如果saliency mapping更多的集中于目标物体,那么会大大提升我们对模型的信任程度)Proactive Pseudo-Intervention——因果推断与对比学习结合,用于寻找因果特征_第1张图片
本文的一个效果图,其中WBP是本文提出的一种迭代的计算saliency mapping的方法,这个后面会详细说。通过本文方法,对比LRP我们可以看到模型将注意力更多的集中到了真正有因果联系的物体上,而几乎不会去关注背景信息,这使得我们的模型有了相当好的out-of-domain能力。

具体实施过程

本文基于这么一个前提假设来实施因果推断:真正的因果关系可以通过人为的干预来影响最终的目标结果,即 the target label will be changed only if causally-relevant features are perturbed。于是本文利用对比学习去生成伪干预,这也正是本文题目的由来。通过这种对因果信息进行伪干预的方式,本文一次性实现了因果信息推断、鲁棒学习、预测模型的可解释性。同时本文将saliency mapping和contrastive learning相结合,共同来促进模型的学习,直接来看本文框架Proactive Pseudo-Intervention——因果推断与对比学习结合,用于寻找因果特征_第2张图片
首先通过WBP,模型会对图像的一个因果特征区域进行计算,之后模型会通过将求得的特征区域进行box,并且依照box生成一个mask图(当然这里也可以直接不box,将不规则特征区域生成mask图,就是文章中所谓的 hard-masking,后面作者有对照实验),作者将其称之为causal mask,因为我们希望这里的mask可以屏蔽掉图像的因果信息。之后作者将被mask后的图像输入到分类模型中,企图让分类模型去分类这个mask后的图像。这里作者基于这么一个假设:如果图像不存在因果信息(具体到本模型中,就是因果信息被上一步生成的mask完全屏蔽),那么分类器进行分类时的结果将是错误的。我们来看具体的公式表达:
Proactive Pseudo-Intervention——因果推断与对比学习结合,用于寻找因果特征_第3张图片
其中 x i ∗ x^*_i xi代表第 i i i个图像 x i x_i xi所生成的mask后的图像; s m ( . ) s_m(.) sm(.)代表求saliency mapping的过程;(2)式是一个将saliency mapping转化为mask的操作,其中 ω \omega ω σ \sigma σ是两个阈值超参数,控制阈值左右的值逼近1或0;(3)中的 ⊢ y \vdash y y意思是label y y y的相反值,例如在二分类问题中,如果真实的label y = 1 y=1 y=1,那么这里就是 ⊢ y = 0 \vdash y=0 y=0,而对于多分类(N类)问题,如果 y = i y=i y=i,那么 ⊢ y = j ( j ≠ i ) \vdash y=j (j≠i) y=j(j=i)这个具体是如何实现的,由于没有源码,我个人也不得而知。通过这种方式,作者构建出了 x i ∗ x^*_i xi,并且在文中,作者将其称之为 x i x_i xi的负样本。
WBP方法是通过对 f θ ( . ) f_{\theta}(.) fθ(.)进行反向求得的,意在找出真正的因果部分,然而如果仅仅通过上面的公式进行训练,产生的saliency map可能会包含所有的可能是因果特征的部分,因此未来避免模型出现这种情况,作者又加了一个损失限制: L r e g = ∣ ∣ s m ∣ ∣ 1 , ( m = 1 , . . . , M ) . L_{reg }= ||s_m ||_1 , (m = 1,...,M). Lreg=sm1,(m=1,...,M).通过这个限制,模型可以避免产生大范围的saliency map(这里作者显然默认了一个假设,通常一张图像的因果部分仅仅占图像的很少一部分),这样salency map只会显示最具有因果特征的部分,从效果图也可以看出来,本文方法所产生的saliency map对比与传统方法显然暗淡很多,仅仅在最可能的部分才有显示。
此外未了避免 f θ ( . ) f_{\theta}(.) fθ(.)对图像是否被mask产生判别结果的影响,作者又引入了另一条线,也就是上图的 x j x_j xj部分,通过将另一张图的mask贴到 x i x_i xi上,从而生成没有没mask掉因果特征的图,在这里作者将其称之为 x i x_i xi的正样本。由于正样本没有被刻意mask掉因果特征部分,因此在进行分类时,我们还要力求将其分到正确的类别在这里插入图片描述
最后,模型通过联合优化几个损失,来进行整个流程的优化在这里插入图片描述
这里再添一点我个人的理解:本文名为对比学习,实则在作者的流程中并没有给出显式的对比学习过程,个人认为这里已经构造出了因果正负样本,如果需要进行对比学习,也完全可以用这样本来进行。

WBP的实现

Proactive Pseudo-Intervention——因果推断与对比学习结合,用于寻找因果特征_第4张图片

通俗来说,WBP是通过计算每个像素点对最后分类结果的影响来生成saliency map的。
在这里插入图片描述
这里我们将输入简化为向量 x 0 x^0 x0,其中 x l x^l xl是第 l l l层的特征向量, W ~ l \tilde{W}^{l} W~l代表从第 l l l层到最后的所有网络计算。在这里插入图片描述
然后对于 x x x的第 k k k个元素,其saliency map可以表示为这种形式,其中 [ W ~ 0 ] m k [\tilde{W}^0]_{mk} [W~0]mk是一个单独的元素,通过与 [ x ] k [x]_k [x]k相乘便可以得到我们想要的结果。
所有接下来我们的关键就是计算 W ~ l \tilde{W}^l W~l。我们让 g l ( x l ) g^l (x^l ) gl(xl)表示单独的第 l l l层网络,也就是在这里插入图片描述
那么我们就可以通过迭代的来从后往前求 W ~ l \tilde{W}^l W~l,通过在这里插入图片描述
其中 G ( . , . ) G(.,.) G(.,.)代表更新条件,依照具体网络结构来说的,作者给出了常见的 G ( . , . ) G(.,.) G(.,.)的一些表示Proactive Pseudo-Intervention——因果推断与对比学习结合,用于寻找因果特征_第5张图片

实验分析

首先作者给出了和其他几种方法求得的saliency map的定性比较结果Proactive Pseudo-Intervention——因果推断与对比学习结合,用于寻找因果特征_第6张图片
可以看到,对于gradient-based model(2、3),其saliency map十分分散或者微弱,这表示其缺乏特征性,由此训练结果的精度难以达到要求。对于LRP及其变体,尽管可以产生shaper的saliency map,但是其无法将目标与背景分开,因为其认为背景有助于我们的目标任务,然而这是错误的。
之后作者又给了一个非常有趣的实验Proactive Pseudo-Intervention——因果推断与对比学习结合,用于寻找因果特征_第7张图片
当我们移除saliency map中的一些反应比较强烈的点时,对模型性能的影响。可以看到,当仅仅被移除了100个像素点时,WBP方法训练的模型其精度就会产生极大的下降,这反应了这个模型几乎不会利用背景信息来进行决策,这正是我们想要的结果。
此外作者在医学图像(包括后面的补充材料)中还测试了不同的mask的效果:
Proactive Pseudo-Intervention——因果推断与对比学习结合,用于寻找因果特征_第8张图片
Proactive Pseudo-Intervention——因果推断与对比学习结合,用于寻找因果特征_第9张图片
Proactive Pseudo-Intervention——因果推断与对比学习结合,用于寻找因果特征_第10张图片
其中hard我们可以看做是几乎严格的0-1 mask,而soft则是不那么严格的0-1 mask,可以看到使用box方法的mask始终要优于不规则形状的mask。这里作者通篇没有给出解释,个人认为是由于在进行不规则形状的mask时,mask后的图像虽然理论上去掉了因果部分的像素,但模型依然可以依照去掉部分的轮廓获得部分信息,因此无法完全满足作者一开始的理论分析

个人总结

本文通过将因果推断与对比学习(这里个人感觉对比学习部分有点牵强)相结合的方式,实现了模型对于因果信息之外的信息的鲁棒性,同时通过生成的saliency map可以给人对模型以直观的信心。但是本文依然有些地方个人认为有待提高:
1、首先是本文对于负样本的判别上,作者意图将去掉因果特征的样本判别为错误label(或者说不能被判别为正确标签),然而这使得模型依然有着想探索背景信息的倾向:比如说猴子和鱼和鸟的分类,模型无法在图像上寻找到鸟时,模型会常识用背景信息 天空 来判断其label不属于鸟,这就违背了本文的初衷。个人认为对于负样本,应该使得模型对其判别的结果是所有类的平均,即图像上没有鸟,只有天空时,我的判别结果对于三个类都是 1 3 \frac{1}{3} 31的概率。
2、对于正样本的mask方式,作者意图将其他不相关的图像的mask用在目标图像上,虽然理论上没有问题,然而在实际应用时,如果数据集中的目标大多都处在图像中间,那么其他图像的mask依然可能将目标图像的因果特征给大大的屏蔽掉,由此使得 L a d L_{ad} Lad损失不合法。个人认为可以取第一步mask的区域的其他区域进行mask。
3、对于细长的目标,特别是对角线分布的目标,本文的box方式可能会出现mask过度的现象,如何对box的范围进行权衡,将是本文需要面临的一个问题。

你可能感兴趣的:(Proactive Pseudo-Intervention——因果推断与对比学习结合,用于寻找因果特征)