CREATING MESSAGE PASSING NETWORKS

概述:邻居聚集和消息传递

The "MessagePassing" Base Class

PyG提供了消息传递基类,用于创建GNN自动化的消息传递机制。用户只需要定义函数 γ 和 ϕ \gamma 和 \phiγ和ϕ ,分别表示为message(),update()。聚集操作有aggr="add", aggr="mean" or aggr="max"等。

下面是一些相关方法的简介:

MessagePassing(aggr="add", flow="source_to_target", node_dim=-2):定义了一个聚集机制,三个参数分别聚集方式,消息传递方向以及沿哪个维度进行传播。

MessagePassing.propagate(edge_index, size=None, **kwargs):首次调用开始传播消息。获取边索引edge_index和所有用于构造消息和更新节点嵌入的附加数据。这个函数不但可以用于方阵,而且也可以用于二分图等非方阵图,但是需要传递size参数表明矩阵形状size=(N, M)。

MessagePassing.message(...):构造消息到节点i,但是根据传播方向有两个中情况,如果边方向是(j,i) 且 flow="source_to_target",即边是j指向i,而且消息流向是源节点到目的节点,或者相反。通常将中心节点表示为i,邻居节点表示为j。

MessagePassing.update(aggr_out, ...):更新每个节点i的嵌入向量,接受聚合的输出作为第一个参数以及最初传递给propagate()的任何参数。

GCN层实现

GCN层的数学定义如下:

 

邻居节点的特征首先通过权重矩阵Θ \mathbf{\Theta}Θ进行变换,然后使用他们的度进行标准化,最终加和。将其步骤写为如下几步:

向邻接矩阵添加自循环。

线性变换节点特征矩阵。

计算归一化系数。

Normalize节点特性。

对相邻节点特征进行归纳(add聚合)。

步骤1-3通常是在消息传递之前计算的。使用MessagePassing基类可以很容易地处理步骤4-5。全层实现如下所示:
 

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        """
        Args:
            in_channels 就是输入节点特征维度
            out_channels 节点输出特征的维度
        
        """
        super().__init__(aggr='add')  # "Add" 聚集操作 (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: 在邻接矩阵中添加自循环
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: 对节点的特征矩阵进行转换
        x = self.lin(x)

        # Step 3: 归一化
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype) # 计算节点的度
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: 消息传递
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

你可能感兴趣的:(论文研读,pytorch)