GCN torch_geometric: Source Code Analysis of the message_passing Method

Source Code Analysis of the message_passing Method

  • Preface
  • Source Location
  • Source Code and Annotations
    • The MessagePassing Class
    • The propagate Method
  • Acknowledgments and References

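Preface

I have recently been reading papers on graph neural networks, and reproducing them means digging through source code.
There are some source code walkthroughs online, for example of torch_geometric, but very little material goes into the specific methods inside it.
So I am sharing my study notes: first, so that I can look them up later; second, to give a few pointers to others struggling with the same code; third, so that helpful readers can fill in what I missed; and finally, in the hope that it prompts more experienced people to share their own code and ideas. ^_^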

Source Location

torch_geometric\nn\conv\message_passing.py

Source Code and Annotations

import inspect
from collections import OrderedDict
from types import MappingProxyType

import torch
from torch_geometric.utils import scatter_

# Predefined special arguments of the message method
msg_special_args = set([
    'edge_index',
    'edge_index_i',
    'edge_index_j',
    'size',
    'size_i',
    'size_j',
])
# Predefined special arguments of the aggregate method
aggr_special_args = set([
    'index',
    'dim_size',
])
# Predefined special arguments of the update method
update_special_args = set([])

The MessagePassing Class

This is the base class for creating message passing layers.
Its formula is:

$$\mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} \left(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j}\right) \right)$$

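$\square_{j \in \mathcal{N}(i)}$ denotes the aggregation method, such as sum, mean, or max, and is carried out by self.aggregate;
$\gamma_{\mathbf{\Theta}}$ denotes the transformation that combines a node's own information with the aggregated neighbor information, and is carried out by self.update;
$\phi_{\mathbf{\Theta}}$ denotes the transformation applied to neighbor information, and is carried out by self.message.
See the official documentation for more details.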
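
To make the three roles in the formula concrete, here is a minimal pure-Python sketch on a toy graph. It does not use torch_geometric; the function names, the edge list, and the choices "message = neighbor feature", "aggregation = sum", and "update = take the aggregated value" are all assumptions made for illustration.

```python
# Toy graph: each edge is a (source j, target i) pair,
# which corresponds to flow="source_to_target".
edges = [(0, 1), (2, 1), (1, 2)]
x = [1.0, 10.0, 100.0]  # one scalar feature per node


def message(x_i, x_j):
    # phi: in this sketch a message is simply the neighbor's feature.
    return x_j


def update(x_i, aggregated):
    # gamma: in this sketch the node just takes the aggregated value.
    return aggregated


def propagate(x, edges):
    # The "square" operator: sum all messages arriving at each target node.
    aggregated = [0.0] * len(x)
    for j, i in edges:
        aggregated[i] += message(x[i], x[j])
    return [update(x[i], aggregated[i]) for i in range(len(x))]


result = propagate(x, edges)
print(result)  # [0.0, 101.0, 10.0]
```

Node 0 has no incoming edge and ends up with 0.0, node 1 receives from nodes 0 and 2, and node 2 receives from node 1.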

class MessagePassing(torch.nn.Module):
    """
    Args:
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`).
            (default: :obj:`"add"`)
        flow (string, optional): The flow direction of message passing
            (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
            (default: :obj:`"source_to_target"`)
        node_dim (int, optional): The axis along which to propagate.
            (default: :obj:`0`)
    """
    def __init__(self, aggr='add', flow='source_to_target', node_dim=0):
        super(MessagePassing, self).__init__()
        # Aggregation scheme
        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max']

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        self.node_dim = node_dim
        assert self.node_dim >= 0
        # Store the parameters of the (possibly overridden) message method
        self.__msg_params__ = inspect.signature(self.message).parameters

        # Store the parameters of the (possibly overridden) aggregate method
        self.__aggr_params__ = inspect.signature(self.aggregate).parameters
        # Convert the immutable mapping into an ordered dict
        self.__aggr_params__ = OrderedDict(self.__aggr_params__)
        # Drop the first key-value pair, i.e. the 'inputs' parameter
        self.__aggr_params__.popitem(last=False)
        # Wrap it back into a MappingProxyType
        self.__aggr_params__ = MappingProxyType(self.__aggr_params__)

        # Store the parameters of the (possibly overridden) update method
        self.__update_params__ = inspect.signature(self.update).parameters
        self.__update_params__ = OrderedDict(self.__update_params__)
        self.__update_params__.popitem(last=False)
        self.__update_params__ = MappingProxyType(self.__update_params__)

        # Remove the predefined special arguments
        msg_args = set(self.__msg_params__.keys()) - msg_special_args
        aggr_args = set(self.__aggr_params__.keys()) - aggr_special_args
        update_args = set(self.__update_params__.keys()) - update_special_args

        # Merge them into a single set
        self.__args__ = set().union(msg_args, aggr_args, update_args)

    # Check that node tensors agree on their size along node_dim before processing
    def __set_size__(self, size, index, tensor):
        if not torch.is_tensor(tensor):
            pass
        elif size[index] is None:
            size[index] = tensor.size(self.node_dim)
        elif size[index] != tensor.size(self.node_dim):
            raise ValueError(
                (f'Encountered node tensor with size '
                 f'{tensor.size(self.node_dim)} in dimension {self.node_dim}, '
                 f'but expected size {size[index]}.'))

    # Collect every argument that may be needed later
    def __collect__(self, edge_index, size, kwargs):
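        # edge_index has shape [2, E]: an edge relates two nodes, hence the 2; E is the number of edges
        # Message flow direction; with the default source_to_target, i=1 and j=0
        i, j = (0, 1) if self.flow == "target_to_source" else (1, 0)
        # ij maps the suffixes '_i' and '_j' to the matching row of edge_index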
        ij = {"_i": i, "_j": j}

        out = {}
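        # Process the arguments of message, aggregate, and update in turn and store them in out
        for arg in self.__args__:
            # Check the last two characters of the argument name
            if arg[-2:] not in ij.keys():
                out[arg] = kwargs.get(arg, inspect.Parameter.empty)
            else:
                # Look up the edge_index row (0 or 1) for this suffix
                idx = ij[arg[-2:]]
                # Fetch kwargs[arg[:-2]], or Parameter.empty if it is absent
                data = kwargs.get(arg[:-2], inspect.Parameter.empty)

                # If data is empty, store it as-is and move on to the next argument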
                if data is inspect.Parameter.empty:
                    out[arg] = data
                    continue

                # If data is a tuple or list, it must hold two entries; keep the one for this suffix
                if isinstance(data, tuple) or isinstance(data, list):
                    assert len(data) == 2
                    self.__set_size__(size, 1 - idx, data[1 - idx])
                    data = data[idx]

                # Non-tensor data is passed through unchanged
                if not torch.is_tensor(data):
                    out[arg] = data
                    continue

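                # Check that node tensors agree on their size along node_dim
                self.__set_size__(size, idx, data)
                # Gather the per-edge entries with index_select: first argument is the dimension, second the indices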
                out[arg] = data.index_select(self.node_dim, edge_index[idx])

        # Keep each size if it is set; otherwise copy it from the other entry
        size[0] = size[1] if size[0] is None else size[0]
        size[1] = size[0] if size[1] is None else size[1]

        # Add the special message arguments
        out['edge_index'] = edge_index
        out['edge_index_i'] = edge_index[i]
        out['edge_index_j'] = edge_index[j]
        out['size'] = size
        out['size_i'] = size[i]
        out['size_j'] = size[j]

        # Add the special aggregate arguments
        out['index'] = out['edge_index_i']
        out['dim_size'] = out['size_i']

        return out

    # Assign the collected data to each parameter the target method declares
    def __distribute__(self, params, kwargs):
        out = {}
        for key, param in params.items():
            data = kwargs[key]
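            # Check whether a value was provided
            if data is inspect.Parameter.empty:
                # Raise an error if a required value is missing
                if param.default is inspect.Parameter.empty:
                    raise TypeError(f'Required parameter {key} is empty.')
                # Otherwise fall back to the default value
                data = param.default
            # Store it in the dict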
            out[key] = data
        return out
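
The signature bookkeeping in __init__ and __distribute__ can be tried out in isolation with only the standard library. The sketch below is an illustration, not the library's code: ToyLayer, declared_params, and distribute are hypothetical names, and the layer's method signatures are made up for the example.

```python
import inspect
from collections import OrderedDict

MSG_SPECIAL = {'edge_index', 'edge_index_i', 'edge_index_j',
               'size', 'size_i', 'size_j'}
AGGR_SPECIAL = {'index', 'dim_size'}


class ToyLayer:
    """Hypothetical stand-in for a MessagePassing subclass."""

    def message(self, x_j, edge_weight, size_i):
        return x_j

    def aggregate(self, inputs, index, dim_size):
        return inputs

    def update(self, inputs, bias=0.0):
        return inputs


def declared_params(method, drop_first):
    # The signature of a bound method already omits `self`.
    params = OrderedDict(inspect.signature(method).parameters)
    if drop_first:
        # Drop the leading `inputs` parameter, as __init__ does.
        params.popitem(last=False)
    return params


def distribute(params, collected):
    # Same logic as __distribute__: use defaults, fail on missing values.
    out = {}
    for key, param in params.items():
        data = collected[key]
        if data is inspect.Parameter.empty:
            if param.default is inspect.Parameter.empty:
                raise TypeError(f'Required parameter {key} is empty.')
            data = param.default
        out[key] = data
    return out


layer = ToyLayer()
msg_args = set(declared_params(layer.message, False)) - MSG_SPECIAL
aggr_args = set(declared_params(layer.aggregate, True)) - AGGR_SPECIAL
update_params = declared_params(layer.update, True)
update_args = set(update_params)
all_args = msg_args | aggr_args | update_args

# `bias` was never passed to propagate, so its default value is used.
update_kwargs = distribute(update_params, {'bias': inspect.Parameter.empty})
```

Here all_args contains x_j, edge_weight, and bias: these are the names that __collect__ would later look up in the keyword arguments given to propagate.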

The propagate Method

It first prepares the arguments that will be needed, then calls self.message, self.aggregate, and self.update in that order.
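
Part of that preparation is the per-edge gather that __collect__ performs for arguments ending in _i or _j. A rough pure-Python sketch of that step follows; plain lists stand in for tensors, list indexing stands in for index_select, and gather_for_edges is a hypothetical helper name.

```python
def gather_for_edges(x, edge_index, flow="source_to_target"):
    # Same row choice as in the source: i=1, j=0 for source_to_target.
    i, j = (0, 1) if flow == "target_to_source" else (1, 0)
    # One entry per edge, picked by the node index stored in that row.
    x_i = [x[n] for n in edge_index[i]]
    x_j = [x[n] for n in edge_index[j]]
    return x_i, x_j


x = ["a", "b", "c"]                  # one feature per node
edge_index = [[0, 2, 1], [1, 1, 2]]  # row 0: sources, row 1: targets

x_i, x_j = gather_for_edges(x, edge_index)
print(x_i)  # ['b', 'b', 'c']
print(x_j)  # ['a', 'c', 'b']

x_i_rev, x_j_rev = gather_for_edges(x, edge_index, "target_to_source")
print(x_i_rev)  # ['a', 'c', 'b']
```

With the default flow, x_j holds the feature of each edge's source node and x_i the feature of its target node; switching the flow swaps the two.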

    def propagate(self, edge_index, size=None, **kwargs):
        # Initial call that starts propagating messages
        """
        Args:
            edge_index (Tensor): The indices of a general (sparse) assignment
                matrix with shape :obj:`[N, M]` (can be directed or
                undirected).
            size (list or tuple, optional): The size :obj:`[N, M]` of the
                assignment matrix. If set to :obj:`None`, the size will be
                automatically inferred and assumed to be quadratic.
                (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
        """
        # Normalize size so that it ends up as a list of length 2
        size = [None, None] if size is None else size
        size = [size, size] if isinstance(size, int) else size
        size = size.tolist() if torch.is_tensor(size) else size
        size = list(size) if isinstance(size, tuple) else size
        assert isinstance(size, list)  # assert is equivalent to: if not expression: raise AssertionError
        assert len(size) == 2

        # Prepare every argument that may be needed
        kwargs = self.__collect__(edge_index, size, kwargs)

        # Pick out of kwargs the data that message declares as parameters
        msg_kwargs = self.__distribute__(self.__msg_params__, kwargs)
        # The double asterisk unpacks the dict into keyword arguments
        out = self.message(**msg_kwargs)

        aggr_kwargs = self.__distribute__(self.__aggr_params__, kwargs)
        out = self.aggregate(out, **aggr_kwargs)

        update_kwargs = self.__distribute__(self.__update_params__, kwargs)
        out = self.update(out, **update_kwargs)

        return out

    def message(self, x_j):  # pragma: no cover
        r"""Constructs messages to node :math:`i` in analogy to
        :math:`\phi_{\mathbf{\Theta}}` for each edge in
        :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and
        :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`.
        Can take any argument which was initially passed to :meth:`propagate`.
        In addition, tensors passed to :meth:`propagate` can be mapped to the
        respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
        :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`.
        """

        return x_j

    def aggregate(self, inputs, index, dim_size):  # pragma: no cover
        r"""Aggregates messages from neighbors as
        :math:`\square_{j \in \mathcal{N}(i)}`.

        By default, delegates call to scatter functions that support
        "add", "mean" and "max" operations specified in :meth:`__init__` by
        the :obj:`aggr` argument.
        """
        # Covered in detail in a separate blog post (https://blog.csdn.net/qq_39407949/article/details/116855426)
        # In short: aggregate node information using the configured scheme
        return scatter_(self.aggr, inputs, index, self.node_dim, dim_size)

    def update(self, inputs):  # pragma: no cover
        r"""Updates node embeddings in analogy to
        :math:`\gamma_{\mathbf{\Theta}}` for each node
        :math:`i \in \mathcal{V}`.
        Takes in the output of aggregation as first argument and any argument
        which was initially passed to :meth:`propagate`.
        """

        return inputs
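
To illustrate what the aggregation step computes, here is a small pure-Python sketch of scatter-style aggregation over lists. It is not the scatter_ implementation: the function name scatter, its argument order, and the choice to give nodes without any message the value 0.0 are assumptions of this sketch.

```python
def scatter(aggr, inputs, index, dim_size):
    # Group the messages by the target node they are sent to.
    groups = [[] for _ in range(dim_size)]
    for value, target in zip(inputs, index):
        groups[target].append(value)

    # Reduce each group; nodes with no message get 0.0 (sketch assumption).
    if aggr == "add":
        return [sum(g, 0.0) for g in groups]
    if aggr == "mean":
        return [sum(g) / len(g) if g else 0.0 for g in groups]
    if aggr == "max":
        return [max(g) if g else 0.0 for g in groups]
    raise ValueError(f"Unknown aggregation scheme: {aggr}")


messages = [1.0, 100.0, 10.0]  # one message per edge
index = [1, 1, 2]              # target node of each edge (edge_index_i)

print(scatter("add", messages, index, 3))   # [0.0, 101.0, 10.0]
print(scatter("mean", messages, index, 3))  # [0.0, 50.5, 10.0]
print(scatter("max", messages, index, 3))   # [0.0, 100.0, 10.0]
```

The index and dim_size arguments play the same role as the special aggregate arguments filled in by __collect__: index is the target row of edge_index, and dim_size is the number of target nodes.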

Acknowledgments and References

Parts of this post draw on the links below; many thanks to their authors. Thanks♪(・ω・)ノ
Official documentation: Creating Message Passing Networks
https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html
Reference post 1: Graph Convolutional Network (graph convolution, GCN)
https://blog.csdn.net/qq_39388410/article/details/102730998
Reference post 2: CS224w Graph Neural Networks
https://zhuanlan.zhihu.com/p/113862170
Reference post 3: Notes on Reproducing GCN Code with PyTorch and PyTorch Geometric (PyG)
https://blog.csdn.net/meteor_0033/article/details/104205835
Reference post 4: Torch geometric NNConv Source Code Analysis
https://blog.csdn.net/qq_41987033/article/details/103497749
Reference post 5: How to Understand Graph Convolutional Networks (GCN)?
https://www.zhihu.com/question/54504471/answer/332657604
Reference post 6: From Graphs to Graph Convolution: A Tour of Graph Neural Network Models (series)
https://www.cnblogs.com/SivilTaram/p/graph_neural_network_1.html
Reference post 7: torch default parameter initialization_torch_geometric source code reading (part 1)
https://blog.csdn.net/weixin_39946798/article/details/111640849
