I mentioned the GNNExplainer paper before, but the pile of formulas in it is hard to get a handle on, so I went looking for GNNExplainer implementations on GitHub, including the one in DIG. GNNExplainer explains a graph from two angles: a node feature (NF) mask and an edge mask, and it presents the explanation as the subgraph formed by the top-k edges. Based on DIG's implementation, here is a stripped-down explain.py.
import torch
from torch import Tensor
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from math import sqrt
from configuration import data_args
from torch_geometric.data import Batch, Data
from torch.nn.functional import cross_entropy
class ExplainerBase(nn.Module):
    def __init__(self, model: nn.Module, epochs=0, lr=0, explain_graph=False, molecule=False):
        super().__init__()
        self.model = model
        self.lr = lr
        self.epochs = epochs
        self.explain_graph = explain_graph
        self.molecule = molecule
        self.mp_layers = [module for module in self.model.modules() if isinstance(module, MessagePassing)]
        self.num_layers = len(self.mp_layers)
        self.ori_pred = None
        self.ex_labels = None
        self.edge_mask = None
        self.hard_edge_mask = None
        self.num_edges = None
        self.num_nodes = None
        self.device = None

    def __set_masks__(self, x, edge_index, init="normal"):
        (N, F), E = x.size(), edge_index.size(1)
        std = 0.1
        self.node_feat_mask = torch.nn.Parameter(torch.randn(F, requires_grad=True, device=self.device) * 0.1)
        std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
        self.edge_mask = torch.nn.Parameter(torch.randn(E, requires_grad=True, device=self.device) * std)
        # self.edge_mask = torch.nn.Parameter(100 * torch.ones(E, requires_grad=True))
        # register the mask on every MessagePassing layer so propagate can see it
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = True
                module.__edge_mask__ = self.edge_mask

    def __clear_masks__(self):
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = False
                module.__edge_mask__ = None
        self.node_feat_masks = None
        self.edge_mask = None

    @property
    def __num_hops__(self):
        if self.explain_graph:
            return -1
        else:
            return self.num_layers

    def __flow__(self):
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                return module.flow
        return 'source_to_target'

    def forward(self, x: Tensor, edge_index: Tensor, **kwargs):
        self.num_edges = edge_index.shape[1]
        self.num_nodes = x.shape[0]
        self.device = x.device

    def eval_related_pred(self, x, edge_index, edge_masks, **kwargs):
        node_idx = kwargs.get('node_idx')
        node_idx = 0 if node_idx is None else node_idx  # graph level: 0, node level: node_idx
        related_preds = []
        for ex_label, edge_mask in enumerate(edge_masks):
            self.edge_mask.data = float('inf') * torch.ones(edge_mask.size(), device=data_args.device)
            ori_pred = self.model(x=x, edge_index=edge_index, **kwargs)
            self.edge_mask.data = edge_mask
            masked_pred = self.model(x=x, edge_index=edge_index, **kwargs)
            # mask out important elements for fidelity calculation
            self.edge_mask.data = - edge_mask  # keep Parameter's id
            maskout_pred = self.model(x=x, edge_index=edge_index, **kwargs)
            # zero_mask
            self.edge_mask.data = - float('inf') * torch.ones(edge_mask.size(), device=data_args.device)
            zero_mask_pred = self.model(x=x, edge_index=edge_index, **kwargs)
            related_preds.append({'zero': zero_mask_pred[node_idx],
                                  'masked': masked_pred[node_idx],
                                  'maskout': maskout_pred[node_idx],
                                  'origin': ori_pred[node_idx]})
        return related_preds
EPS = 1e-15
class GNNExplainer(ExplainerBase):
    r"""The GNN-Explainer model from the `"GNNExplainer: Generating
    Explanations for Graph Neural Networks"
    <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact
    subgraph structures and small subsets of node features that play a
    crucial role in a GNN's predictions.

    .. note::

        For an example of using GNN-Explainer, see
        examples/gnn_explainer.py in the torch_geometric repository.

    Args:
        model (torch.nn.Module): The GNN module to explain.
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`50`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.001`)
    """

    coeffs = {
        'edge_size': 0.005,
        'node_feat_size': 1.0,
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
    }
    def __init__(self, model, epochs=50, lr=0.001, explain_graph=True, molecule=False):
        super(GNNExplainer, self).__init__(model, epochs, lr, explain_graph, molecule)

    def __loss__(self, raw_preds, x_label):
        loss = cross_entropy(raw_preds, x_label)
        m = self.edge_mask.sigmoid()
        loss = loss + self.coeffs['edge_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['edge_ent'] * ent.mean()

        if self.mask_features:
            m = self.node_feat_mask.sigmoid()
            loss = loss + self.coeffs['node_feat_size'] * m.sum()
            ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
            loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss

    def gnn_explainer_alg(self, x: Tensor, edge_index: Tensor, ex_label: Tensor,
                          mask_features: bool = False, **kwargs) -> Tensor:
        # initialize a mask
        patience = 10
        self.to(x.device)
        self.mask_features = mask_features

        # train to get the mask
        optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask], lr=self.lr)
        best_loss = 4.0
        count = 0
        for epoch in range(1, self.epochs + 1):
            if mask_features:
                h = x * self.node_feat_mask.view(1, -1).sigmoid()
            else:
                h = x
            raw_preds = self.model(data=Batch.from_data_list([Data(x=h, edge_index=edge_index)]))
            loss = self.__loss__(raw_preds, ex_label)
            # if epoch % 10 == 0:
            #     print(f'#D#Loss:{loss.item()}')

            # early stopping with patience on the mask loss
            is_best = (loss < best_loss)
            if not is_best:
                count += 1
            else:
                count = 0
                best_loss = loss
            if count >= patience:
                break

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return self.edge_mask.data

    def forward(self, x, edge_index, mask_features=False, positive=True, **kwargs):
        r"""Learns and returns an edge mask that plays a crucial role in
        explaining the prediction made by the GNN.

        Args:
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            mask_features (bool, optional): Whether to also learn a node
                feature mask. (default: :obj:`False`)
            **kwargs (optional): Additional arguments passed to the GNN module.

        :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`)
        """
        self.model.eval()
        # self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)

        # Calculate mask; here we only explain the class with label 1.
        ex_label = torch.tensor([1]).to(data_args.device)
        self.__clear_masks__()
        self.__set_masks__(x, edge_index)
        edge_mask = self.gnn_explainer_alg(x, edge_index, ex_label)
        # edge_masks.append(self.gnn_explainer_alg(x, edge_index, ex_label))
        # with torch.no_grad():
        #     related_preds = self.eval_related_pred(x, edge_index, edge_masks, **kwargs)
        self.__clear_masks__()

        sorted_results = edge_mask.sort(descending=True)
        return edge_mask.detach(), sorted_results.indices.cpu(), edge_index.cpu()

    def __repr__(self):
        return f'{self.__class__.__name__}()'
The entry point is GNNExplainer.forward; here I only explain samples classified as label 1.
def forward(self, x, edge_index, mask_features=False, positive=True, **kwargs):
    r"""Learns and returns an edge mask that plays a crucial role in
    explaining the prediction made by the GNN.

    Args:
        x (Tensor): The node feature matrix.
        edge_index (LongTensor): The edge indices.
        mask_features (bool, optional): Whether to also learn a node
            feature mask. (default: :obj:`False`)
        **kwargs (optional): Additional arguments passed to the GNN module.

    :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`)
    """
    self.model.eval()

    # Calculate mask; here we only explain the class with label 1.
    ex_label = torch.tensor([1]).to(data_args.device)
    self.__clear_masks__()
    self.__set_masks__(x, edge_index)
    edge_mask = self.gnn_explainer_alg(x, edge_index, ex_label)
    self.__clear_masks__()

    sorted_results = edge_mask.sort(descending=True)
    return edge_mask.detach(), sorted_results.indices.cpu(), edge_index.cpu()
The function first calls __clear_masks__() to clear any previous edge mask (even though there is none yet), and then calls __set_masks__ to set a randomly initialized edge mask. It returns edge_mask together with the edge indices sorted by edge_mask. The edge mask itself is computed by gnn_explainer_alg.
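To make the call pattern concrete, here is a minimal usage sketch. The names model, data and k are assumptions for illustration (a trained graph-classification model and a torch_geometric Data object), not part of the code above:

# minimal usage sketch; `model` and `data` are assumed to exist
explainer = GNNExplainer(model, epochs=50, lr=0.001, explain_graph=True)
edge_mask, sorted_idx, edge_index = explainer(data.x, data.edge_index)

# keep the top-k edges as the explanation subgraph, per the top-k idea above
k = 10
topk_edge_index = edge_index[:, sorted_idx[:k]]  # shape [2, k]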
The next piece of code only needs a quick look: it mainly sets up the randomly initialized edge mask and NF mask, though I removed the NF mask setup to simplify the code. The full version can be found in DIG's code.
def __set_masks__(self, x, edge_index, init="normal"):
    (N, F), E = x.size(), edge_index.size(1)
    std = 0.1
    self.node_feat_mask = torch.nn.Parameter(torch.randn(F, requires_grad=True, device=self.device) * 0.1)
    std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
    self.edge_mask = torch.nn.Parameter(torch.randn(E, requires_grad=True, device=self.device) * std)
    # self.edge_mask = torch.nn.Parameter(100 * torch.ones(E, requires_grad=True))
    # register the mask on every MessagePassing layer so propagate can see it
    for module in self.model.modules():
        if isinstance(module, MessagePassing):
            module.__explain__ = True
            module.__edge_mask__ = self.edge_mask
gnn_explainer_alg mainly computes an optimal edge mask; the NF mask is left out for now.
def gnn_explainer_alg(self, x: Tensor, edge_index: Tensor, ex_label: Tensor,
                      mask_features: bool = False, **kwargs) -> Tensor:
    # initialize a mask
    patience = 10
    self.to(x.device)
    self.mask_features = mask_features

    # train to get the mask
    optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask], lr=self.lr)
    best_loss = 4.0
    count = 0
    for epoch in range(1, self.epochs + 1):
        if mask_features:
            h = x * self.node_feat_mask.view(1, -1).sigmoid()
        else:
            h = x
        raw_preds = self.model(data=Batch.from_data_list([Data(x=h, edge_index=edge_index)]))
        loss = self.__loss__(raw_preds, ex_label)

        # early stopping with patience on the mask loss
        is_best = (loss < best_loss)
        if not is_best:
            count += 1
        else:
            count = 0
            best_loss = loss
        if count >= patience:
            break

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return self.edge_mask.data
Here the edge mask (and the NF mask) is treated as a trainable parameter and optimized the way a neural network is trained. The loss function is as follows:
def __loss__(self, raw_preds, x_label):
    loss = cross_entropy(raw_preds, x_label)
    m = self.edge_mask.sigmoid()
    loss = loss + self.coeffs['edge_size'] * m.sum()
    ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
    loss = loss + self.coeffs['edge_ent'] * ent.mean()

    if self.mask_features:
        m = self.node_feat_mask.sigmoid()
        loss = loss + self.coeffs['node_feat_size'] * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

    return loss
In this loss, raw_preds is already the output computed with the mask applied, because __clear_masks__ is never called again during the whole training process to remove the mask.
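A quick way to sanity-check this, valid between __set_masks__ and the final __clear_masks__ (explainer is the hypothetical instance from the usage sketch above):

for module in explainer.model.modules():
    if isinstance(module, MessagePassing):
        # every message-passing layer still carries the same trainable mask
        assert module.__explain__
        assert module.__edge_mask__ is explainer.edge_mask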
The loss consists of 3 parts:
1. the cross entropy between the masked prediction and the target label;
2. the size of the edge mask (the sum of its sigmoid values), which keeps the explanation small;
3. the element-wise entropy of the mask, which pushes each value towards 0 or 1, i.e. makes the mask discrete.
Ignoring the NF mask, the whole loss can roughly be written as
$loss = CrossEntropy(f(data, model, edge\_mask), label) + Size(edge\_mask) + Discrete(edge\_mask)$
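Matching the __loss__ code term by term (edge mask only), with $m = \sigma(edge\_mask)$, $E$ the number of edges, and $\epsilon$ = EPS guarding the logarithms, this expands to:

$loss = CE\big(f(data, model, m),\ label\big) + 0.005 \sum_{e} m_e + \frac{1}{E} \sum_{e} \big[-m_e \log(m_e + \epsilon) - (1 - m_e) \log(1 - m_e + \epsilon)\big]$

where 0.005 is coeffs['edge_size'] and the entropy term carries coeffs['edge_ent'] = 1.0.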
Minimizing this loss yields the optimal edge mask. That leaves one question the code above does not answer: how does the edge mask actually take effect inside $f(data, model, edge\_mask)$? That is, how does adding an edge mask change the model's computation? To answer this, we have to look at how torch geometric supports GNNExplainer.
The core class here is MessagePassing. In the version without GNNExplainer support, its propagate function reads:
coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)
msg_kwargs = self.inspector.distribute('message', coll_dict)
out = self.message(**msg_kwargs)
aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)
update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)
In the example, the message function simply returns x_j, and update likewise returns its input tensor unchanged; aggregate pools information from the surrounding nodes.
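To make this concrete, here is a minimal toy MessagePassing subclass of exactly that shape. PlainConv is an illustrative name of my own, not a class from DIG or torch_geometric:

import torch
from torch_geometric.nn import MessagePassing

class PlainConv(MessagePassing):
    # Toy layer: message forwards the neighbor feature x_j unchanged,
    # aggregation sums over incoming edges, update is the identity.
    def __init__(self):
        super().__init__(aggr='add', flow='source_to_target')

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j: [edge_num, node_feature_dim], the source-node feature of each edge
        return x_j

    def update(self, aggr_out):
        # aggr_out: [node_num, node_feature_dim] after aggregation
        return aggr_out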
The out returned by out = self.message(**msg_kwargs) has shape [edge_num, node_feature_dim]; it should be the feature vector of each edge's source node (x_j). The out returned by out = self.aggregate(out, **aggr_kwargs) has shape [node_num, node_feature_dim]; it should be each node's representation after aggregation. With GNNExplainer support added, the propagate function reads:
coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)

msg_kwargs = self.inspector.distribute('message', coll_dict)
out = self.message(**msg_kwargs)

# For `GNNExplainer`, we require a separate message and aggregate
# procedure since this allows us to inject the `edge_mask` into the
# message passing computation scheme.
if self.__explain__:
    edge_mask = self.__edge_mask__.sigmoid()
    # Some ops add self-loops to `edge_index`. We need to do the
    # same for `edge_mask` (but do not train those).
    if out.size(self.node_dim) != edge_mask.size(0):
        loop = edge_mask.new_ones(size[0])
        edge_mask = torch.cat([edge_mask, loop], dim=0)
    assert out.size(self.node_dim) == edge_mask.size(0)
    out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))

aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)

update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)
You can see that a block of code has been inserted between message and aggregate. Ignoring the guard and the reshaping, it boils down to out = out * edge_mask: for a node $x_i$, before aggregation with its surrounding nodes, each neighbor's vector (counting only in-edges: if a directed edge $(u, v)$ exists, then $u$ counts as a neighbor of $v$) is first multiplied by the corresponding edge mask value. (My understanding here may be off; corrections from experts are welcome.)
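To watch the mask act numerically, here is a small sketch that reuses the PlainConv toy from above and registers a mask by hand, the way __set_masks__ does; this assumes a torch_geometric version whose propagate matches the snippet just quoted:

x = torch.ones(3, 2)                      # 3 nodes, 2 features, all ones
edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])    # edges 0->1, 1->2, 2->0
conv = PlainConv()
print(conv(x, edge_index))                # unmasked: each node receives one neighbor, all ones

# emulate what __set_masks__ does: register a mask directly on the layer
conv.__explain__ = True
conv.__edge_mask__ = torch.tensor([10.0, -10.0, 10.0])  # sigmoid ~ [1, 0, 1]
print(conv(x, edge_index))                # node 2's incoming message (edge 1->2) is ~zeroed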
。也就是对于结点 x i x_i xi,在其与周围结点聚合的同时,先将周围结点(只算入度,就是存在有向边 ( u , v ) (u, v) (u,v) 那么 u u u 算 v v v 的周围结点)的向量乘以edge mask。(这里的理解可能有误,欢迎大佬指正)