主要内容有:Philoshopy(哲学思想)、Explainer(解释器)、Explanations(解释)、Explainer Algorithm(解释器算法)、Explanation Metrics(解释度量)
该模块提供了一组工具来解释 PyG 模型的预测或解释数据集的基本现象,详细信息可以参考“GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks”
。
我们使用torch_geometric.explain.Explanation
类来表示解释,该类是一个Data
对象,包含数据的节点、边、特征和任何属性的掩码。
torch_geometric.expllain.Explainer
类设计用于处理所有可解释性参数(有关更多详细信息,请参阅torch_geometric.explainn.config.ExplainerConfig
类):
torch_geometric.expllain.algorithm
模块中,例如GNNExplainer
explanation_type="phenomenon"
或者explanation_type="model"
)mask="object"
或者mask="attributes"
)threshold_type="topk"
或者threshold_type="hard"
)该类允许用户轻松比较不同的可解释性方法,并在不同类型的掩码之间轻松切换,同时确保高层的代码框架保持不变。
基础类:object
图神经网络实例级解释的一个解释器类。
参数:
torch.nn.Module
)——要解释的模型ModelConfig
,默认为None
node_mask_type
。默认为NoneThresholdConfig
,默认为None方法
get_prediction(*args, **kwargs)→ Tensor
:返回模型对输入图的预测。get_masked_prediction(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], node_mask: Optional[Union[Tensor, Dict[str, Tensor]]] = None, edge_mask: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs)→ Tensor
:返回应用了节点和边掩码的输入图上的模型预测。__call__(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Optional[Tensor] = None, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]
:计算给定输入和目标的GNN的解释。torch.no.grad()
计算的。
get_target(prediction: Tensor)→ Tensor
:从给定的预测中返回模型的目标。node_mask_type
。默认为None
基础类:Data
、ExplanationMixin
持有同质图的所有已得到的解释。解释对象是Data对象,可以包含节点属性和边属性。如果需要,它还可以保存原始图形。
参数:
node_mask——形状为[num_nodes, 1], [1, num_features]
或[num_nodes, num_features]
的node-level掩码,默认为None
edge_mask——形状为[num_edges]
的edge_level掩码,默认为None
kwargs——其他属性参数
方法:
validate(raise_on_error: bool = True)→ bool
:验证解释对象的正确性。
get_explanation_subgraph()→ Explanation
:返回归纳子图,其中所有属性为零的节点和边都被屏蔽掉了。
get_complement_subgraph()→ Explanation
:返回归纳子图,其中具有任何属性的所有节点和边都被屏蔽掉。
visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[List[str]] = None, top_k: Optional[int] = None)
:通过对所有节点的节点掩码求和,创建节点要素重要性的条形图。
参数:
visualize_graph(path: Optional[str] = None, backend: Optional[str] = None)
: 使具有与边重要性相对应的边不透明度的解释图可视化。
参数:
graphviz
”、“networkx
”)。如果设置为“None”,将根据可用的系统包使用最合适的可视化后端。(默认值:None)
基础类:HeteroData
、ExplanationMixin
包含所有已获得的对异构图的解释。解释对象是HeteroData对象,可以包含节点属性和边属性。如果需要,它还可以保存原始图形。
方法:
validate(raise_on_error: bool = True)→ bool
:验证解释对象的正确性。get_explanation_subgraph()→ HeteroExplanation
:返回归纳子图,其中所有属性为零的节点和边都被屏蔽掉了。get_complement_subgraph()→ HeteroExplanation
:返回归纳子图,其中具有任何属性的所有节点和边都被屏蔽掉。visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[Dict[str, List[str]]] = None, top_k: Optional[int] = None)
:通过对每个节点类型的所有节点的节点掩码求和,创建节点特征重要性的条形图。1) ExplainerAlgorithm
:用于实现解释器算法的抽象基类。
方法
- abstract forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]
:计算解释。
参数:
- model:要解释的模型
- x:一个同质图或异质图的输入节点特征
- edge_index:一个同质图或异质图的输入边索引
- target:模型的目标
- index:对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- kwargs:传递给 model 的其他关键字参数。
- abstract supports()→ bool
:检查解释器是否支持self.explainer_config
、self.model_config
中提供的用户定义设置。
- property explainer_config: ExplainerConfig
:返回已连接的解释器配置。
- property model_config: ModelConfig
:返回已连接的模型配置
- connect(explainer_config: ExplainerConfig, model_config: ModelConfig)
:将解释器和模型配置连接到解释器算法。
2) DummyExplainer
:返回随机解释的伪解释程序(用于测试目的)。
基础类:ExplainerAlgorithm
方法
- forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], edge_attr: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]
:解释计算
参数:
- model:要解释的模型
- x:一个同质图或异质图的输入节点特征
- edge_index:一个同质图或异质图的输入边索引
- target:模型的目标
- index:对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- kwargs:传递给 model 的其他关键字参数。
- supports()→ bool
:检查解释器是否支持self.explainer_config
、self.model_config
中提供的用户定义设置。
GNNExplainer
:来自 “GNNExplainer: Generating Explanations for Graph Neural Networks” 论文中的GNN-Explainer模型用于识别在GNN预测中起关键作用的紧凑子图结构和节点特征。
基础类:ExplainerAlgorithm
有关使用GNNEexplainer的示例,请参见examples/explaine/gnn_explainer.py、examples/explain/gnn_eexplainer_ba_shapes.py和examples/explain/gn_explainer_link_pred.py。
参数:
coeffs
中默认设置的附加超参数。方法:
forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Explanation
:计算解释supports()→ bool
:检查解释器是否支持self.explainer_config
、self.model_config
中提供的用户定义设置。CaptumExplainer
:一种基于Captum的解释器,用于识别在GNN的预测中起关键作用的紧凑子图结构和节点特征。forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]
:计算解释supports()→ bool
:检查解释器是否支持self.explainer_config
、self.model_config
中提供的用户定义设置。PGExplainer
: "Parameterized Explainer for Graph Neural Network"论文中的PGExplainer模型。train()
进行训练:explainer = Explainer(
model=model,
algorithm=PGExplainer(epochs=30, lr=0.003),
explanation_type='phenomenon',
edge_mask_type='object',
model_config=ModelConfig(...),
)
# 针对各种节点级别或图级别的预测进行训练:
for epoch in range(30):
for index in [...]: # Indices to train against.
loss = explainer.algorithm.train(epoch, model, x, edge_index,
target=target, index=index)
# 获得最终解释:
explanation = explainer(x, edge_index, target=target, index=0)
参数:
- epochs:要训练的epochs数。
- lr:学习率,默认为0.003
- kwargs:用于覆盖coeffs
中默认设置的附加超参数。
方法:
- reset_parameters()
:重置模型中所有科学系参数
- train(epoch: int, model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)
:训练基础的解释者模型。需要在能够做出预测之前被调用。
参数:
- epoch:训练阶段的当前阶段。
- model:要被解释的模型
- x:同质图的输入节点特征。
- edge_index:同质图的输入边索引。
- target:模型的目标
- index:对模型输出的索引进行解释。需要是一个单独的索引。(默认值:None)
- kwargs:传递给model的其他关键字参数。
- forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Explanation
:计算解释
参数:
- model:要解释的模型
- x:一个同质图或异质图的输入节点特征
- edge_index:一个同质图或异质图的输入边索引
- target:模型的目标
- index:对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- kwargs:传递给 model 的其他关键字参数。
- supports()→ bool
:检查解释器是否支持self.explainer_config
、self.model_config
中提供的用户定义设置。
5) AttentionExplainer
:使用基于注意力的GNN(例如,GATConv、GATv2Conv或TransformerConv)产生的注意力系数作为边解释的解释器。各层和头部的注意力得分将根据reduce argument进行汇总。
基础类:ExplainerAlgorithm
参数:
- reduce:降低各层和头部注意力得分的方法。(默认值:“max”)
方法:
- forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Explanation
:计算解释
- - supports()→ bool
:检查解释器是否支持self.explainer_config
、self.model_config
中提供的用户定义设置。
解释的质量可以通过各种不同的方法来判断。PyG支持以下开箱即用的指标:
groundtruth_metrics
:将解释掩码与ground-truth解释掩码进行比较和评估。fidelity
:评估一个Explainer给出的Explanation的真实度 ,参见 “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” 论文characterization_score
:返回组件式特征化分数(the componentwise characterization score),参见 “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” 论文fidelity_curve_auc
:返回真实度曲线的AUC,参见 “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” 论文unfaithfulness
:评估一个Explanation对一个不足的GNN预测因子的真实度,参见 "Evaluating Explainability for Graph Neural Networks"论文