【可解释学习】PyG可解释学习模块torch_geometric.explain

PyG可解释学习模块torch_geometric.explain

  • Philoshopy
  • Explainer
  • Explanations
  • Explainer Algorithm
  • Explanation Metrics
  • 参考资料

torch_geometric.explain是PyTorch Geometric库中的一个模块,用于解释和可视化图神经网络(GNN)模型的预测结果。它提供了一些方法来解释模型的预测结果、边权重和节点重要性。

主要内容有:Philoshopy(哲学思想)、Explainer(解释器)、Explanations(解释)、Explainer Algorithm(解释器算法)、Explanation Metrics(解释度量)

Philoshopy

该模块提供了一组工具来解释 PyG 模型的预测或解释数据集的基本现象,详细信息可以参考“GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks”

我们使用torch_geometric.explain.Explanation类来表示解释,该类是一个Data对象,包含数据的节点、边、特征和任何属性的掩码。
torch_geometric.expllain.Explainer类设计用于处理所有可解释性参数(有关更多详细信息,请参阅torch_geometric.explainn.config.ExplainerConfig类):

  • 使用的算法来自torch_geometric.expllain.algorithm模块中,例如GNNExplainer
  • 要计算的解释类型(例如explanation_type="phenomenon"或者explanation_type="model"
  • 节点和边的不同类型掩码(例如mask="object"或者mask="attributes"
  • 掩码的任何后处理(例如threshold_type="topk"或者threshold_type="hard"

该类允许用户轻松比较不同的可解释性方法,并在不同类型的掩码之间轻松切换,同时确保高层的代码框架保持不变。

Explainer

Explainer
基础类:object
图神经网络实例级解释的一个解释器类。
参数

  • model(torch.nn.Module)——要解释的模型
  • algorithm(解释器算法)——解释算法
  • explanation_type(解释类型或str)——要计算的解释类型。可能的值为:
    • “model”:解释模型预测
    • “phenomenon”:解释模型试图预测的现象。
      在实践中,这意味着解释算法将计算其相对于模型输出(“model”)或目标输出(“phenomenon”)的损失。
  • model_config——模型配置,参见ModelConfig,默认为None
  • node_mask_type——要应用于节点的掩码类型。可能的值为:
    • “None”:不会在节点上应用任何掩码。
    • “object”:将屏蔽每个节点。
    • “common_attributes”:将掩盖每个特征。
    • “attributes”:将屏蔽所有节点上的每个特征。
  • edge_mask_type——要应用于边的掩码类型。具有的可能值例如node_mask_type。默认为None
  • threshold_config——阈值设置,可选数值参见ThresholdConfig,默认为None

方法

  • get_prediction(*args, **kwargs)→ Tensor:返回模型对输入图的预测。
    如果模型模式为“regression”,则预测将作为标量值返回。如果模型模式为“multiclass_classification”或“binary_classifications”,则预测将作为预测类标签返回。
  • get_masked_prediction(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], node_mask: Optional[Union[Tensor, Dict[str, Tensor]]] = None, edge_mask: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs)→ Tensor:返回应用了节点和边掩码的输入图上的模型预测。
  • __call__(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Optional[Tensor] = None, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]:计算给定输入和目标的GNN的解释。
    如果收到一条错误消息,如“Trying to backward through the graph a second time”,请确保提供的目标是用torch.no.grad()计算的。
    • x——通志图或异质图的输入节点特征。
    • edge_index——同质或异质图的输入边索引。
    • target——模型的目标。如果解释类型是“phenomenon”,则必须提供解释对象。如果解释类型是“ model”,那么目标应该设置为 Nothing,并且会自动推断出来。(默认值: None)
    • index——对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
    • **kwargs——要传递给GNN的其他参数。
  • get_target(prediction: Tensor)→ Tensor:从给定的预测中返回模型的目标。
    如果模型模式为“regression”类型,则按原样返回预测;如果模型模式类型为“multiclass_classification”或“binary_classifications”,则按预测类标签返回预测。

ExplainerConfig
用于存储和验证高级解释参数的配置类。
参数

  • explanation_type——要计算的解释类型。可能的值为:
    • “model”——解释模型预测。
    • “pheonmenon”——解释模型试图预测的现象。
      在实践中,这意味着解释算法将计算它们相对于模型输出(“model”)或目标输出(“pheonmenon”)的损失。
  • node_mask_type——要应用于节点的掩码类型。可能的值为(默认值:None):
    • “None”:不会在节点上应用任何掩码。
    • “object”:将屏蔽每个节点。
    • “common_attributes”:将掩盖每个特征。
    • “attributes”:将屏蔽所有节点上的每个特征。
  • edge_mask_type——要应用于边的掩码类型。具有的可能值例如node_mask_type。默认为None

ModelConfig
用于存储模型参数的配置类。
参数

  • model——模型的模式。可能的值为:
    • “binary_classification”:一个二分类模型。
    • “multiclass_classification”:一种多类分类模型。
    • “regression”:一个回归模型
  • task_level——模型的任务级别。可能的值为:
    • “node”:一个node-level预测模型
    • “edge”:一个edge-level预测模型
    • “graph”:一个graph-level预测模型
  • return_type——模型的返回类型。可能的值为(默认值:None):
    • “raw”:模型返回原始值。
    • “probs”:模型返回概率值
    • “log_probs”:模型返回对数概率

ThresholdConfig
用于存储和验证阈值参数的配置类。
参数

  • threshold_type——要应用的阈值的类型。可能的值为:
    • “None”:没有阈值被应用
    • “hard”:将hard阈值应用于每个掩码。掩码中值低于该值的元素设置为0,其他元素设置为1。
    • “topk”:soft阈值被应用于每个掩码。保留每个掩码的top obj:value元素,其他元素设置为0。
    • “topk_hard”:“topk”相同,但保留的所有元素的值都设置为1。
  • value——设置阈值时要使用的值。(默认值:None)

Explanations

Explanation
基础类:DataExplanationMixin
持有同质图的所有已得到的解释。解释对象是Data对象,可以包含节点属性和边属性。如果需要,它还可以保存原始图形。
参数:

  • node_mask——形状为[num_nodes, 1], [1, num_features][num_nodes, num_features]的node-level掩码,默认为None

  • edge_mask——形状为[num_edges]的edge_level掩码,默认为None

  • kwargs——其他属性参数
    方法

  • validate(raise_on_error: bool = True)→ bool:验证解释对象的正确性。

  • get_explanation_subgraph()→ Explanation:返回归纳子图,其中所有属性为零的节点和边都被屏蔽掉了。

  • get_complement_subgraph()→ Explanation:返回归纳子图,其中具有任何属性的所有节点和边都被屏蔽掉。

  • visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[List[str]] = None, top_k: Optional[int] = None):通过对所有节点的节点掩码求和,创建节点要素重要性的条形图。
    参数

    • path: 保存绘图的路径。如果设置为“None”,将动态显示绘图。(默认值:None)
    • feat_labels: 特征的标签。(默认为“None”)
    • top_k:绘制top k 个特征。如果None,绘制所有特征。(默认值: None)
  • visualize_graph(path: Optional[str] = None, backend: Optional[str] = None): 使具有与边重要性相对应的边不透明度的解释图可视化。
    参数

    • path: 保存绘图的路径。如果设置为“None”,将动态显示绘图。(默认值:None)
    • backend: 用于可视化的图形绘制后端(“graphviz”、“networkx”)。如果设置为“None”,将根据可用的系统包使用最合适的可视化后端。(默认值:None)

HeteroExplanation
基础类:HeteroDataExplanationMixin
包含所有已获得的对异构图的解释。解释对象是HeteroData对象,可以包含节点属性和边属性。如果需要,它还可以保存原始图形。
方法:

  • validate(raise_on_error: bool = True)→ bool:验证解释对象的正确性。
  • get_explanation_subgraph()→ HeteroExplanation:返回归纳子图,其中所有属性为零的节点和边都被屏蔽掉了。
  • get_complement_subgraph()→ HeteroExplanation:返回归纳子图,其中具有任何属性的所有节点和边都被屏蔽掉。
  • visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[Dict[str, List[str]]] = None, top_k: Optional[int] = None):通过对每个节点类型的所有节点的节点掩码求和,创建节点特征重要性的条形图。
    参数
    • path: 保存绘图的路径。如果设置为“None”,将动态显示绘图。(默认值:None)
    • feat_labels: 特征的标签。(默认为“None”)
    • top_k:绘制top k 个特征。如果None,绘制所有特征。(默认值: None)

Explainer Algorithm

1) ExplainerAlgorithm:用于实现解释器算法的抽象基类。
方法
- abstract forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]:计算解释。
参数
- model:要解释的模型
- x:一个同质图或异质图的输入节点特征
- edge_index:一个同质图或异质图的输入边索引
- target:模型的目标
- index:对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- kwargs:传递给 model 的其他关键字参数。
- abstract supports()→ bool:检查解释器是否支持self.explainer_configself.model_config中提供的用户定义设置。
- property explainer_config: ExplainerConfig:返回已连接的解释器配置。
- property model_config: ModelConfig:返回已连接的模型配置
- connect(explainer_config: ExplainerConfig, model_config: ModelConfig):将解释器和模型配置连接到解释器算法。
2) DummyExplainer:返回随机解释的伪解释程序(用于测试目的)。
基础类:ExplainerAlgorithm
方法
- forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], edge_attr: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]:解释计算
参数:
- model:要解释的模型
- x:一个同质图或异质图的输入节点特征
- edge_index:一个同质图或异质图的输入边索引
- target:模型的目标
- index:对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- kwargs:传递给 model 的其他关键字参数。
- supports()→ bool:检查解释器是否支持self.explainer_configself.model_config中提供的用户定义设置。

  • GNNExplainer:来自 “GNNExplainer: Generating Explanations for Graph Neural Networks” 论文中的GNN-Explainer模型用于识别在GNN预测中起关键作用的紧凑子图结构和节点特征。
    基础类:ExplainerAlgorithm
    有关使用GNNEexplainer的示例,请参见examples/explaine/gnn_explainer.py、examples/explain/gnn_eexplainer_ba_shapes.py和examples/explain/gn_explainer_link_pred.py。
    参数

    • epochs:要训练的epochs数。默认为100
    • lr:学习率,默认为0.01
    • kwargs:用于覆盖coeffs中默认设置的附加超参数。

    方法:

    • forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Explanation:计算解释
    • supports()→ bool:检查解释器是否支持self.explainer_configself.model_config中提供的用户定义设置。
      3) CaptumExplainer:一种基于Captum的解释器,用于识别在GNN的预测中起关键作用的紧凑子图结构和节点特征。
      基础类:ExplainerAlgorithm
      这个解释器算法使用Captum来计算属性。目前,支持以下归因方法:
    • captum.attr.IntegratedGradients
    • captum.attr.Saliency
    • captum.attr.InputXGradient
    • captum.attr.Deconvolution
    • captum.attr.ShapleyValueSampling
    • captum.attr.GuidedBackprop
      参数
    • attribution_method:要使用的Captum归因方法。可以是字符串或captum.attr方法。
    • kwargs:Captum归因方法的其他参数。
      方法
    • forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Union[Explanation, HeteroExplanation]:计算解释
    • supports()→ bool:检查解释器是否支持self.explainer_configself.model_config中提供的用户定义设置。
      4) PGExplainer: "Parameterized Explainer for Graph Neural Network"论文中的PGExplainer模型。
      基础类:ExplainerAlgorithm
      在内部,它利用神经网络来识别在GNN的预测中起关键作用的子图结构。重要的是,在生成解释之前,PGExplainer需要通过train()进行训练:
explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=30, lr=0.003),
    explanation_type='phenomenon',
    edge_mask_type='object',
    model_config=ModelConfig(...),
)

# 针对各种节点级别或图级别的预测进行训练:
for epoch in range(30):
    for index in [...]:  # Indices to train against.
        loss = explainer.algorithm.train(epoch, model, x, edge_index,
                                         target=target, index=index)

# 获得最终解释:
explanation = explainer(x, edge_index, target=target, index=0)

参数
- epochs:要训练的epochs数。
- lr:学习率,默认为0.003
- kwargs:用于覆盖coeffs中默认设置的附加超参数。
方法
- reset_parameters():重置模型中所有科学系参数
- train(epoch: int, model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs):训练基础的解释者模型。需要在能够做出预测之前被调用。
参数
- epoch:训练阶段的当前阶段。
- model:要被解释的模型
- x:同质图的输入节点特征。
- edge_index:同质图的输入边索引。
- target:模型的目标
- index:对模型输出的索引进行解释。需要是一个单独的索引。(默认值:None)
- kwargs:传递给model的其他关键字参数。
- forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Explanation:计算解释
参数:
- model:要解释的模型
- x:一个同质图或异质图的输入节点特征
- edge_index:一个同质图或异质图的输入边索引
- target:模型的目标
- index:对模型输出的索引进行解释。可以是单个索引或索引的张量。(默认值:None)
- kwargs:传递给 model 的其他关键字参数。
- supports()→ bool:检查解释器是否支持self.explainer_configself.model_config中提供的用户定义设置。
5) AttentionExplainer:使用基于注意力的GNN(例如,GATConv、GATv2Conv或TransformerConv)产生的注意力系数作为边解释的解释器。各层和头部的注意力得分将根据reduce argument进行汇总。
基础类:ExplainerAlgorithm
参数:
- reduce:降低各层和头部注意力得分的方法。(默认值:“max”)
方法:
- forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)→ Explanation:计算解释
- - supports()→ bool:检查解释器是否支持self.explainer_configself.model_config中提供的用户定义设置。

Explanation Metrics

解释的质量可以通过各种不同的方法来判断。PyG支持以下开箱即用的指标:

  • groundtruth_metrics:将解释掩码与ground-truth解释掩码进行比较和评估。
  • fidelity:评估一个Explainer给出的Explanation的真实度 ,参见 “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” 论文
  • characterization_score:返回组件式特征化分数(the componentwise characterization score),参见 “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” 论文
  • fidelity_curve_auc:返回真实度曲线的AUC,参见 “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” 论文
  • unfaithfulness:评估一个Explanation对一个不足的GNN预测因子的真实度,参见 "Evaluating Explainability for Graph Neural Networks"论文

参考资料

  1. 可解释性研究(四)-GNNExplainer的内部实现
  2. GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks
  3. torch_geometric.explain官方文档
  4. GNNExplainer: Generating Explanations for Graph Neural Networks
  5. Captum
  6. PGExplainer
  7. Evaluating Explainability for Graph Neural Networks

你可能感兴趣的:(可解释机器学习,pytorch,可解释学习,PyG,explainer,GNNExplainer,PGExplainer,CaptumExplainer)