【图神经网络论文整理】(一)—— Causal Attention for Interpretable and Generalizable Graph Classification:CAL

【图神经网络论文整理】(一)—— Causal Attention for Interpretable and Generalizable Graph Classification:CAL_第1张图片


  • KDD '22: Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining
  • August 2022
  • Pages 1696–1705
  • 论文地址

本文介绍的论文是中科大王翔教授等人在KDD2022上发表的《Causal Attention for Interpretable and Generalizable Graph Classification》。

作者强调了当前基于注意力和池化的GNN在图分类中的泛化问题,并且提出了一种新的用于图分类的因果注意力学习策略(CAL),使GNN在过滤掉捷径特征的同时利用因果特征,最后在合成数据集和真实数据集上的大量实验证明了CAL的有效性。


一、背景

目前大多数图神经网络GNN在图分类这项任务中,遵循learning to attend这种模式,这能够最大程度学习图数据与标签之间的关系。

但是这种范式使得模型学习到的映射是基于统计相关性的,忽略了数据之间的因果关系,没有区分特征的因果效应和非因果效应,这会导致模型将非因果特征作为捷径特征,将其用来进行预测。

这也是导致ODD数据泛化性不好的一个原因,因为模型是将捷径特征作为预测,如果ODD数据不遵循这种数据分布,那么将会使得GNN的泛化能力变差。

所以作者提出了新的想法CAL,设计一个模型能够将图数据中的因果图和琐碎图进行分离,降低嘈杂因子对于该项任务的影响。

二、模型方法

2.1 结构因果模型

【图神经网络论文整理】(一)—— Causal Attention for Interpretable and Generalizable Graph Classification:CAL_第2张图片
对于图分类的任务,我们可以定义上面的SCM图,这幅图显示了该任务当中的一些变量之间的因果关系。

  • G:代表图数据
  • C:代表图数据中的因果特征
  • S:图数据中的捷径特征,也就是混杂因子,会导致模型不按正确的方式进行预测
  • R:图的表征信息
  • Y:图的类别信息

对于因果理论来讲,图数据中会存在两种数据特征,一种是捷径特征,另外一种是因果特征,我们希望分类任务当中我们使用因果特征,也就是走G->C->R->Y这条道路,但是数据种往往会存在非因果特征,模型会使用这部分数据用于预测,也就是图中上面的路。

这也就是说如果我们的数据分布不一致,那么显然模型的泛化性一定会降低,因为在测试集种很有可能没有捷径这部分特征。

说了很久捷径特征和因果特征,这里我举个例子说明一下这两个特征的区别:

比如我们现在有一个图像分类的任务,是要判断一张图片是不是羊,我们的训练集中所有有羊的图片中都会有草地,但是验证集中只有羊,此时如果使用模型训练,模型就会将草和羊做关联,认为凡是有草的图片就会有羊,没有草的图片即使有羊也不会识别成功,这时该网络就犯了错,因为它的学习方式有误,我们识别羊应该通过图像的纹理判断是否具备羊的特征,而不是通过草这个特征,即便该模型在训练集中可以做到百分百准确率,但是一旦换一张一只羊在南极的照片,他就会识别失败。

显然这个例子中的捷径特征就是草这个因素,因果特征应该是羊图像的纹理特征,叫捷径特征也就是模型利用该特征偷懒了,并没有学习到数据的本质关系。

2.2 后门调整

该因果关系图中,存在后门路径C<-G->S->R->Y,其中S在C和Y之间就扮演了混杂因子的角色,尽管C和Y没有直接联系,但是存在后门路径,他会建立一个错误的联系通过S,进而导致模型出现错误的判断,所以切断后门路径使GNN利用因果特征至关重要。

因果理论中为我们提供一种方法来消除后门路径:利用do算子进行干预

【图神经网络论文整理】(一)—— Causal Attention for Interpretable and Generalizable Graph Classification:CAL_第3张图片

2.3 因果和琐碎参与图

为了能够同时学习到图数据中的因果特征和捷径特征,作者提出了因果参与图和琐碎参与图这两个概念,希望通过这两个图来分类特征信息,然后用因果图进行分类,而琐碎参与图视为没有因果特征信息的图。
【图神经网络论文整理】(一)—— Causal Attention for Interpretable and Generalizable Graph Classification:CAL_第4张图片
目标就是通过掩码分数来从完整图数据中捕获因果参与图和琐碎参与图,学习到这两个图不仅可以指导GNN的表示学习,还可以回答GNN使用什么知识进行预测,这对于解释性、隐私性和公平性的应用至关重要。

2.4 因果注意力学习

作者提出了CAL框架来实现上面的想法

2.4.1 计算注意力掩码分数

作者为了实现干预,所以需要从图数据中分离因果特征和捷径特征,这也就对应上面说到的两个图,但是作者在学习这两部分特征时使用了注意力模块,这个模块会根据节点以及边的特征信息来学习出对应的注意力分数。
【图神经网络论文整理】(一)—— Causal Attention for Interpretable and Generalizable Graph Classification:CAL_第5张图片
利用这些分数,我们可以形成两个矩阵,分别是 M a M_a Ma M x M_x Mx以及他们的互补矩阵,所谓互补矩阵就是使用全1矩阵减去他们得到的矩阵,然后使用这两个矩阵提取图数据中的信息。

在这里插入图片描述

2.4.2 分离因果图和琐碎图

上面给出了因果图和琐碎图的计算方式,然后作者采用了两个GNN层获取了图的表示,使用读出函数和分类器进行预测,因为我们最终是要做分类,所以需要使用Read Out函数读取整幅图的数据,将其转化为一个向量输入到后面的分类器中进行分类。

在这里插入图片描述
因为因果参与图是真正用来提取信息进行分类任务的,所以我们要基于因果参与图的预测结果构造损失函数,也就是分类任务常用的交叉熵损失函数。

【图神经网络论文整理】(一)—— Causal Attention for Interpretable and Generalizable Graph Classification:CAL_第6张图片

但是我们的琐碎参与图也会存在输出,因为我们不想让它参与模型的预测,也就是不希望捷径特征和Y产生因果关系,所以希望琐碎参与图的输出与Y无关系,那么就希望它的输出更加接近均匀分布,所以对于琐碎参与图也会有相应的损失函数:

【图神经网络论文整理】(一)—— Causal Attention for Interpretable and Generalizable Graph Classification:CAL_第7张图片
该损失函数利用了KL散度,它就是用于衡量两个分布的相似性,因为我们希望它不参与预测,也就是希望它的输出更加接近均匀分布。

2.4.3 因果干预

为了缓解混杂因子对于模型的影响,需要进行后门调整,作者的想法是对混杂因素进行分层,将因果参与图与琐碎参与图进行随机配对,组成一个干预图。

【图神经网络论文整理】(一)—— Causal Attention for Interpretable and Generalizable Graph Classification:CAL_第8张图片
就是将因果参与图的读出表示随机与琐碎参与图的读出表示进行相加,然后将这个向量用于分类器。

最后,整个CAL框架模型的损失可以定义为这三个损失函数之和:

在这里插入图片描述

三、实验结果

3.1 合成数据集

【图神经网络论文整理】(一)—— Causal Attention for Interpretable and Generalizable Graph Classification:CAL_第9张图片

3.2 真实数据集

【图神经网络论文整理】(一)—— Causal Attention for Interpretable and Generalizable Graph Classification:CAL_第10张图片

四、总结

该作者从因果角度重新审视了用于图分类的GNN建模,发现当前的GNN学习策略倾向于利用捷径特征来支持他们的预测,然而,捷径特征实际上扮演了一个令人困惑的角色。它在因果特征和预测之间建立了一条后门路径,从而误导GNN学习虚假相关性。为了减轻混杂效应,提出了GNN的因果注意学习策略(CAL)。CAL以因果理论的后门调整为指导,它鼓励GNN利用因果特征,而忽略捷径部分。广泛的实验结果和分析证实了其有效性。

CAL算法流程
【图神经网络论文整理】(一)—— Causal Attention for Interpretable and Generalizable Graph Classification:CAL_第11张图片

你可能感兴趣的:(图神经网络,神经网络,分类,深度学习,人工智能)