0、最近在看一篇关于在模型层面上做可解释的图神经网络,记录一下。
论文链接:XGNN:Towards Model-Level Explanations of Graph Neural Networks
图神经网络通过聚合邻居节点的信息来学习节点特征,虽然在许多图任务上取得了较大的成功,但是GNN和传统的深度学习模型一样,是黑盒模型。如果不能理解GNN模型,就不能完全信任它们。
本文作者提出了一种在模型层面上解释GNN的方法。通过训练一个图生成器(Graph Generator)使得生成的图模式/子图能最大化GNN模型的某种预测,作者将图生成器描述为一种强化学习任务,来预测每一步应该在哪两个节点间添加边(连接哪两个节点的边),此外在生成器上添加了一些图规则,使得生成的图是有效的(比如一个节点的度小于4,这个要根据实际问题考虑)。在合成的数据集以及真实数据集上都取得了不错的效果。
GNNs虽然在图任务上效果不错,但它仍是一个黑盒模型。如果不去解释它,就不能理解它,在一些安全、隐私、公平极重要的场合,我们就不能相信这些模型。
关于目前的可解释性方面的研究,主要包括两大块:实例级和模型级。
实例级的解释是通过确定输入中的重要特征来解释对于给定实例的预测(如CNN识别猫的时候,一步步可视化特征)。主要方法有:(1)基于梯度的方法(2)中间特征图可视化(3)基于遮挡的方法。
模型级的解释通过研究哪些输入模式可以导致某种预测结果 去 解释模型的一般行为。主要方法:输入优化。
对于实例级的解释,要通过解释大量的输入实例才能相信模型,需要大量人力物力,而模型级的解释更加普遍与高级。
文章的具体做法:训练一个图生成器,以便生成的图可以解释深度图模型。将这个过程表述为一种强化学习,每一步,给定一个图,图生成器预测连接哪两个节点,生成一个图,此外,加入一些图规则,使生成的图是有效的。注意:图生成器只是一个框架,可以根据具体数据集和要解释的GNN推广到任何合适的图生成方法。
图卷积操作的定义:
X_i是第i个图卷积层的输入(n个节点*d维),X_{i+1}是第i个卷积层的输出。A^=A+I,A是邻接矩阵,I是单位矩阵,D是节点的度矩阵,W是要训练的参数,f(·)是非线性激活函数。
模型级的解释 通常会生成 可以最大化预测结果的 优化输入。随机初始化输入,依次迭代更新输入直到达到某一目标(比如预测某一类概率最大)。这样的优化输入可以被认为是对模型某一行为(作出预测)的解释。输入优化类似于深度神经网络,只不过DL训练的是参数,而优化输入中模型参数是固定的,要去训练输入。
虽然它可以用于深度神经网络。但不能用于图数据,原因有三:
(1)图结构由离散的邻接矩阵表示,不能通过反向传播直接优化邻接矩阵。
(2)图像images中,抽象后的图是高层次的特征表示,在图Graphs的情况下,抽象是没有意义的。
(3)得到的图Graph可能对生物、化学规则是无效的,需要添加一些图规则。(原子节点的度不能超过其最大化合价)
实例级解释:最近的GNN解释工具GNN Explainer 提出通过学习soft masks在实例级上解释深层图模型。以一个实例为例,将soft mask应用于图的边缘和节点特征,并更新mask使预测结果与原始预测结果保持一致。然后通过对mask进行阈值选取一些图边和节点特征,并将它们作为重要的边缘和特征进行实例的预测。
---------------------------------------------------万能解释工具GNN Explainer------------------------------------------------
我先去看看 GNN Explainer(知乎论文分享)
1、不去探究模型内部,专注于信息本身,研究哪些信息对预测结果是重要的。
2、从图结构+特征信息解释。
3、从实例级解释。
核心:将一个已经训练好的GNN和其预测结果作为输入,然后通过输出一个子图以及该子图上更少的特征,表示其输出最大程度的影响了该GNN的预测结果。
---------------------------------------------------万能解释工具GNN Explainer------------------------------------------------
模型级解释:给定一个训练好的GNN,模型级的解释就是要解释哪些子图/图模式导致了某种预测y.
XGNN要做的事就是:找出来这个子图/图模式。用公式表示就是:
G*是我们需要的子图/图模式,P概率,f(.)表示训练好的GNN模型,c是标签。
手工去分析得出子图/图模式,费时费力,作者提出了一种图生成的方法。
核心:对于每一步,图形生成器都会基于当前图形生成一个新图形。数学上:
输入是第t步的图Gt,其中,g(.)是图生成器,Xt 第t步的节点特征矩阵,At邻接矩阵。
输出是第t+1步的图Gt+1,X节点特征矩阵,A邻接矩阵。
然后在预训练的GNNs模型的指导下 训练图生成器g(.)。这里作者将图生成的过程描述为一个强化学习问题。
---------------------------------------------------------强化学习------------------------------------------------------------------
就是一个序列做决策的问题。每一步都会对最终结果产生影响,每一步都会产生一个回馈,怎么做才能使得每一步都达到最优?
---------------------------------------------------------强化学习------------------------------------------------------------------
首先用GCNs学习聚合节点的特征X^。其中Gt是当前图,C是一个候选节点的集合。(我们要选两个节点,再将它们连接起来,生成新图)
然后用一个多层感知器预测出 开始节点 的概率,再用softmax选出来一个概率最大的节点,作为初始节点。(公式6表示,只能从图Gt中选开始节点,因为候选节点不能作为开始节点)
然后选择一个 结束节点 。x^start代表 已经选择的那个节点的特征。(公式8表示 除去开始节点 的 剩余节点中 选择一个作为结束节点)
上图2中完整的反映出了这个步骤:
---------------------------------------------------------数学盲区------------------------------------------------------------------
---------------------------------------------------------数学盲区------------------------------------------------------------------
action at在第t步的损失函数:
LCE(·,·)为交叉熵损失,Rt为第t步的奖励(Reward)函数(如果at能生成一个对预测结果c类分数较高且有效的图,Rt就越大),奖励Rt有两部分组成,第一部分Rt,f来自训练模型f(.)的反馈,第二部分Rt,r来自图规则。对第t步,奖励包含了对图Gt+1的中间奖励以及最终奖励。
Rt,f(Gt+1)是中间奖励
λ1和λ2是超参数
在合成数据集、真实数据集上 评估。
1、合成数据集:Networkx software package 得到的(一个graph中有环存在,label=1,否则label=0)节点是没有label的,主要研究GNN扑捉图结构的能力。
2、真实数据集:MUTAG,分为两类(化合物会对细菌产生诱变label=1,否则label=0)节点是有label的,每个节点(原子)可能是碳氢氧氟氯…主要研究GNN捕获图结构与节点标签的能力。
效果很好
在本文中,我们提出了一种新的方法XGNN,在模型层次上解释图模型。具体地说,我们提出通过图的生成来寻找能够最大化某种预测的图模式。我们将其表示为一个强化学习问题,并迭代生成图模式。我们训练一个图形生成器,对于每一步,它预测如何向当前图中添加一条边。此外,我们加入了一些图规则,以鼓励生成的图是有效的和可理解的。最后,我们在合成数据集和真实数据集上进行了实验,以证明我们提出的XGNN的有效性。实验结果表明,生成的图有助于发现哪些模式可以最大限度地预测训练后的神经网络。生成的解释有助于验证和更好地理解经过训练的gnn是否能以我们预期的方式做出预测。此外,我们的结果还表明,所产生的解释有助于改进训练的模型。