最近在看图神经网络相关的论文,涉及论文复现,所以需要啃源码。
网上虽然已有一部分对源码的解析,如torch_geometric()方法,但深入到其中具体方法的资料却寥寥无几。
特将自己的学习笔记分享出来:一来方便自己以后查阅;二来给和我一样苦恼的同胞一点思路;三来方便热心网友补充;最后,为后面大神分享代码或思路抛砖引玉。^_^
torch_geometric\nn\conv\message_passing.py
import inspect
from collections import OrderedDict
from types import MappingProxyType
import torch
from torch_geometric.utils import scatter_
# Reserved argument names for the message method. MessagePassing fills these
# in itself (edge indices and sizes), so they are excluded from the set of
# names resolved from the keyword arguments passed to propagate().
msg_special_args = {
    'edge_index',
    'edge_index_i',
    'edge_index_j',
    'size',
    'size_i',
    'size_j',
}

# Reserved argument names for the aggregate method (target index per message
# and the number of output nodes), likewise filled in by MessagePassing.
aggr_special_args = {
    'index',
    'dim_size',
}

# Reserved argument names for the update method (currently none).
update_special_args = set()
MessagePassing 是用于创建消息传递层的基类。
公式为:
$$\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j}\right) \right)$$
其中,$\square_{j \in \mathcal{N}(i)}$ 表示聚合方法,如求和、取平均或取最大值,由 self.aggregate 执行;
$\gamma_{\mathbf{\Theta}}$ 表示对节点自身信息与聚合后的邻居信息进行变换,由 self.update 执行;
$\phi_{\mathbf{\Theta}}$ 表示对邻居信息的变换,由 self.message 执行。
更多详细信息请参见官网。
class MessagePassing(torch.nn.Module):
    r"""Base class for creating message passing layers of the form

    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
        \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
        \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right),

    where :math:`\phi_{\mathbf{\Theta}}` is implemented by :meth:`message`,
    :math:`\square` by :meth:`aggregate` and :math:`\gamma_{\mathbf{\Theta}}`
    by :meth:`update`.

    Args:
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`).
            (default: :obj:`"add"`)
        flow (string, optional): The flow direction of message passing
            (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
            (default: :obj:`"source_to_target"`)
        node_dim (int, optional): The axis along which to propagate.
            (default: :obj:`0`)
    """

    def __init__(self, aggr='add', flow='source_to_target', node_dim=0):
        super(MessagePassing, self).__init__()

        # Aggregation scheme, forwarded to scatter_ by the default aggregate().
        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max']

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim
        assert self.node_dim >= 0

        # Introspect the (possibly overridden) message/aggregate/update
        # signatures so propagate() can hand each one exactly the keyword
        # arguments it declares.
        self.__msg_params__ = inspect.signature(self.message).parameters
        # aggregate() and update() receive their first argument (`inputs`)
        # positionally from propagate(), so it is dropped from their params.
        self.__aggr_params__ = self._params_without_first(self.aggregate)
        self.__update_params__ = self._params_without_first(self.update)

        # User-facing argument names that __collect__ must resolve from
        # propagate()'s **kwargs; the reserved names are filled in by
        # __collect__ itself and are therefore removed here.
        msg_args = set(self.__msg_params__) - msg_special_args
        aggr_args = set(self.__aggr_params__) - aggr_special_args
        update_args = set(self.__update_params__) - update_special_args
        self.__args__ = msg_args | aggr_args | update_args

    @staticmethod
    def _params_without_first(func):
        """Return the signature parameters of ``func`` without its first
        parameter, as a read-only ``MappingProxyType``."""
        params = OrderedDict(inspect.signature(func).parameters)
        params.popitem(last=False)
        return MappingProxyType(params)

    def __set_size__(self, size, index, tensor):
        """Record or validate the node count at ``size[index]``.

        The node count is read from ``tensor.size(self.node_dim)``. If
        ``size[index]`` is still ``None`` it is filled in place; otherwise the
        tensor must agree with it. Non-tensor values are ignored.

        Raises:
            ValueError: if ``tensor`` disagrees with an already known size.
        """
        if not torch.is_tensor(tensor):
            return
        if size[index] is None:
            size[index] = tensor.size(self.node_dim)
        elif size[index] != tensor.size(self.node_dim):
            raise ValueError(
                (f'Encountered node tensor with size '
                 f'{tensor.size(self.node_dim)} in dimension {self.node_dim}, '
                 f'but expected size {size[index]}.'))

    def __collect__(self, edge_index, size, kwargs):
        """Build every argument that message/aggregate/update may request.

        Args:
            edge_index: index tensor with two rows (source/target node index
                per edge; shape ``[2, E]`` per the original notes).
            size: two-element list of node counts. ``None`` entries are filled
                in place, inferred from the tensors seen.
            kwargs: keyword arguments passed to :meth:`propagate`.

        Returns:
            dict mapping argument names to values. Names ending in ``_i`` /
            ``_j`` are gathered per edge from ``kwargs[name[:-2]]``. Entries
            that could not be resolved hold ``inspect.Parameter.empty``.
        """
        # Which row of edge_index holds the "i" (central) and "j" (neighbour)
        # nodes: the default source_to_target flow gives i=1, j=0.
        i, j = (0, 1) if self.flow == "target_to_source" else (1, 0)
        ij = {"_i": i, "_j": j}

        out = {}
        for arg in self.__args__:
            suffix = arg[-2:]
            if suffix not in ij:
                # Plain argument: pass through unchanged if provided.
                out[arg] = kwargs.get(arg, inspect.Parameter.empty)
                continue

            # `x_i` / `x_j` style argument: look up the base name (`x`) and
            # pick the edge_index row matching the suffix.
            idx = ij[suffix]
            data = kwargs.get(arg[:-2], inspect.Parameter.empty)
            if data is inspect.Parameter.empty:
                out[arg] = data
                continue

            if isinstance(data, (tuple, list)):
                # A pair of tensors (one per node set): the other element
                # still contributes its node count.
                assert len(data) == 2
                self.__set_size__(size, 1 - idx, data[1 - idx])
                data = data[idx]

            if not torch.is_tensor(data):
                out[arg] = data
                continue

            self.__set_size__(size, idx, data)
            # Gather one row per edge along node_dim.
            out[arg] = data.index_select(self.node_dim, edge_index[idx])

        # If only one side's node count is known, assume the other side has
        # the same count (square assignment matrix).
        size[0] = size[1] if size[0] is None else size[0]
        size[1] = size[0] if size[1] is None else size[1]

        # Reserved message arguments.
        out['edge_index'] = edge_index
        out['edge_index_i'] = edge_index[i]
        out['edge_index_j'] = edge_index[j]
        out['size'] = size
        out['size_i'] = size[i]
        out['size_j'] = size[j]

        # Reserved aggregate arguments.
        out['index'] = out['edge_index_i']
        out['dim_size'] = out['size_i']

        return out

    def __distribute__(self, params, kwargs):
        """Select the entries of ``kwargs`` named by ``params``.

        Unresolved entries (``inspect.Parameter.empty``) fall back to the
        parameter's default.

        Raises:
            TypeError: if a parameter without a default could not be resolved.
        """
        out = {}
        for key, param in params.items():
            data = kwargs[key]
            if data is inspect.Parameter.empty:
                if param.default is inspect.Parameter.empty:
                    raise TypeError(f'Required parameter {key} is empty.')
                data = param.default
            out[key] = data
        return out

    def propagate(self, edge_index, size=None, **kwargs):
        r"""The initial call to start propagating messages.

        First prepares every argument that may be needed, then calls
        :meth:`message`, :meth:`aggregate` and :meth:`update` in turn.

        Args:
            edge_index (Tensor): The indices of a general (sparse) assignment
                matrix with shape :obj:`[N, M]` (can be directed or
                undirected).
            size (list or tuple, optional): The size :obj:`[N, M]` of the
                assignment matrix. If set to :obj:`None`, the size will be
                automatically inferred and assumed to be quadratic.
                (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
        """
        # Normalise size to a fresh two-element list. A caller-supplied list
        # is copied too, because __collect__ fills in missing entries in place
        # and must not mutate the caller's object (tuples were already copied).
        size = [None, None] if size is None else size
        size = [size, size] if isinstance(size, int) else size
        size = size.tolist() if torch.is_tensor(size) else size
        size = list(size) if isinstance(size, (tuple, list)) else size
        assert isinstance(size, list)
        assert len(size) == 2

        kwargs = self.__collect__(edge_index, size, kwargs)

        msg_kwargs = self.__distribute__(self.__msg_params__, kwargs)
        out = self.message(**msg_kwargs)

        aggr_kwargs = self.__distribute__(self.__aggr_params__, kwargs)
        out = self.aggregate(out, **aggr_kwargs)

        update_kwargs = self.__distribute__(self.__update_params__, kwargs)
        out = self.update(out, **update_kwargs)

        return out

    def message(self, x_j):  # pragma: no cover
        r"""Constructs messages to node :math:`i` in analogy to
        :math:`\phi_{\mathbf{\Theta}}` for each edge in
        :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and
        :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`.
        Can take any argument which was initially passed to :meth:`propagate`.
        In addition, tensors passed to :meth:`propagate` can be mapped to the
        respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
        :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`.
        """
        return x_j

    def aggregate(self, inputs, index, dim_size):  # pragma: no cover
        r"""Aggregates messages from neighbors as
        :math:`\square_{j \in \mathcal{N}(i)}`.

        By default, delegates call to scatter functions that support
        "add", "mean" and "max" operations specified in :meth:`__init__` by
        the :obj:`aggr` argument.
        """
        # In short: combine the per-edge messages onto their target nodes with
        # the configured scheme. scatter_ is covered in a separate post:
        # https://blog.csdn.net/qq_39407949/article/details/116855426
        return scatter_(self.aggr, inputs, index, self.node_dim, dim_size)

    def update(self, inputs):  # pragma: no cover
        r"""Updates node embeddings in analogy to
        :math:`\gamma_{\mathbf{\Theta}}` for each node
        :math:`i \in \mathcal{V}`.

        Takes in the output of aggregation as first argument and any argument
        which was initially passed to :meth:`propagate`.
        """
        return inputs
部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ
官方说明 Creating Message Passing Networks
https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html
参考博文1 Graph Convolutional Network (图卷积GCN)
https://blog.csdn.net/qq_39388410/article/details/102730998
参考博文2 CS224w 图神经网络(Graph Neural Networks)
https://zhuanlan.zhihu.com/p/113862170
参考博文3 关于Pytorch和Pytorch Geometric(PyG)框架下重现GCN代码的理解
https://blog.csdn.net/meteor_0033/article/details/104205835
参考博文4 Torch geometric NNConv 源码分析
https://blog.csdn.net/qq_41987033/article/details/103497749
参考博文5 如何理解 Graph Convolutional Network(GCN)?
https://www.zhihu.com/question/54504471/answer/332657604
参考博文6 从图(Graph)到图卷积(Graph Convolution):漫谈图神经网络模型 (系列)
https://www.cnblogs.com/SivilTaram/p/graph_neural_network_1.html
参考博文7 torch 默认参数初始化_torch_geometric 源码阅读(一)
https://blog.csdn.net/weixin_39946798/article/details/111640849