where $\square$ denotes a differentiable, permutation-invariant function, e.g. sum, mean, min, max or mul, and $\gamma_{\Theta}$ and $\phi_{\Theta}$ denote differentiable functions such as MLPs. See here for the accompanying tutorial.
Parameters
Methods
propagate(edge_index: Union[Tensor, SparseTensor], size: Optional[Tuple[int, int]] = None, **kwargs)
The initial call to start propagating messages.
edge_updater(edge_index: Union[Tensor, SparseTensor], **kwargs)
message(x_j: Tensor)
Constructs messages from node $j$ to node $i$, in analogy to $\phi_{\Theta}$, for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes $i$ and $j$ by appending _i or _j to the variable name, e.g. x_i and x_j.
aggregate(inputs: Tensor, index: Tensor, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None)
Aggregates messages from neighbors as $\square_{j \in \mathcal{N}(i)}$. Takes in the output of message computation as the first argument and any argument which was initially passed to propagate(). By default, this function delegates its call to the underlying Aggregation module to reduce messages as specified in the aggr argument of __init__().
message_and_aggregate(adj_t: Union[SparseTensor, Tensor])
Fuses the computations of message() and aggregate() into a single function. If applicable, this saves both time and memory since messages do not need to be explicitly materialized. This function will only get called in case it is implemented and propagation takes place based on a torch_sparse.SparseTensor or a torch.sparse.Tensor.
update(inputs: Tensor)
Updates node embeddings for each node $i \in \mathcal{V}$, in analogy to $\gamma_{\Theta}$. Takes in the output of aggregation as the first argument and any argument which was initially passed to propagate().
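To make the interplay of message(), aggregate() and update() concrete, here is a minimal sketch of a custom layer; the name EdgeMeanConv and its residual update are purely illustrative and not part of the library:

import torch
from torch_geometric.nn import MessagePassing

class EdgeMeanConv(MessagePassing):  # hypothetical example layer
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean')  # aggregate() delegates to a mean Aggregation module
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.lin(x)
        # Keyword arguments passed here become available in message()/update().
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j holds the features of the source node j of each edge (the _j suffix).
        return x_j

    def update(self, aggr_out, x):
        # aggr_out holds the aggregated messages for each node i; add a residual term.
        return aggr_out + x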
edge_update()
Computes or updates features for each edge in the graph. This function can take any argument as input which was initially passed to edge_updater(). Furthermore, tensors passed to edge_updater() can be mapped to the respective nodes $i$ and $j$ by appending _i or _j to the variable name, e.g. x_i and x_j.
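Similarly, a hedged sketch of how edge_updater() and edge_update() interact; the per-edge gating scheme below is illustrative only:

import torch
from torch_geometric.nn import MessagePassing

class EdgeGateConv(MessagePassing):  # hypothetical layer using edge_update()
    def __init__(self):
        super().__init__(aggr='sum')

    def forward(self, x, edge_index):
        # First compute a per-edge weight, then use it while propagating messages.
        edge_weight = self.edge_updater(edge_index, x=x)
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def edge_update(self, x_i, x_j):
        # _i / _j map the node features to the target and source of each edge.
        return torch.sigmoid((x_i * x_j).sum(dim=-1, keepdim=True))

    def message(self, x_j, edge_weight):
        return edge_weight * x_j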
register_propagate_forward_pre_hook(hook: Callable)
Registers a forward pre-hook on the module. The hook will be called every time before propagate() is invoked.
register_propagate_forward_hook(hook: Callable)
The hook can modify the input. Input keyword arguments are passed to the hook as a dictionary in inputs[-1]. Returns a torch.utils.hooks.RemovableHandle that can be used to remove the added hook by calling handle.remove().
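A small sketch of registering such a hook, assuming conv is any MessagePassing layer; the logging function is just an example:

def log_inputs(module, inputs):      # pre-hook: runs before every propagate() call
    kwargs = inputs[-1]              # keyword arguments arrive as a dict in inputs[-1]
    print(module.__class__.__name__, list(kwargs.keys()))
    return None                      # returning None keeps the inputs unchanged

handle = conv.register_propagate_forward_pre_hook(log_inputs)
# ... run the model ...
handle.remove()                      # remove the hook again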
register_message_forward_pre_hook(hook: Callable)
register_message_forward_hook(hook: Callable)
register_aggregate_forward_pre_hook(hook: Callable)
register_aggregate_forward_hook(hook: Callable)
register_message_and_aggregate_forward_pre_hook(hook: Callable)
register_message_and_aggregate_forward_hook(hook: Callable)
register_edge_update_forward_pre_hook(hook: Callable)
register_edge_update_forward_hook(hook: Callable)
jittable(typing: Optional[str] = None)
The graph convolutional operator from the “Semi-Supervised Classification with Graph Convolutional Networks” paper

$$\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},$$

where $\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}$ denotes the adjacency matrix with inserted self-loops and $\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}$ its diagonal degree matrix. The adjacency matrix can include values other than 1 representing edge weights via the optional edge_weight tensor.

Its node-wise formulation is given by

$$\mathbf{x}^{\prime}_i = \mathbf{\Theta}^{\top} \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j$$

with $\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}$, where $e_{j,i}$ denotes the edge weight from source node $j$ to target node $i$ (default: 1.0).
Parameters
forward
forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) → Tensor
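A short usage sketch of this forward call on toy data (the tensors below are made up for illustration):

import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(in_channels=16, out_channels=32)
x = torch.randn(4, 16)                            # 4 nodes with 16 features each
edge_index = torch.tensor([[0, 1, 2, 3],          # source nodes j
                           [1, 0, 3, 2]])         # target nodes i
edge_weight = torch.tensor([1.0, 1.0, 0.5, 0.5])  # optional e_{j,i} per edge
out = conv(x, edge_index, edge_weight)            # out has shape [4, 32]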
The GraphSAGE operator from the “Inductive Representation Learning on Large Graphs” paper

$$\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j$$

If project = True, then $\mathbf{x}_j$ will first get projected via

$$\mathbf{x}_j \leftarrow \sigma ( \mathbf{W}_3 \mathbf{x}_j + \mathbf{b})$$
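A brief sketch of using this layer with the projection enabled (toy shapes for illustration):

import torch
from torch_geometric.nn import SAGEConv

conv = SAGEConv(in_channels=16, out_channels=32, project=True)
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])
out = conv(x, edge_index)  # out has shape [4, 32]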
The chebyshev spectral graph convolutional operator from the “Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering” paper
The GraphSAGE operator from the “Inductive Representation Learning on Large Graphs” paper
The graph neural network operator from the “Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks” paper
The GravNet operator from the “Learning Representations of Irregular Particle-detector Geometry with Distance-weighted Graph Networks” paper, where the graph is dynamically constructed using nearest neighbors.
The gated graph convolution operator from the “Gated Graph Sequence Neural Networks” paper
The residual gated graph convolutional operator from the “Residual Gated Graph ConvNets” paper
The graph attentional operator from the “Graph Attention Networks” paper
The fused graph attention operator from the “Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective” paper.
The GATv2 operator from the “How Attentive are Graph Attention Networks?” paper, which fixes the static attention problem of the standard GATConv layer.
The graph transformer operator from the “Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification” paper
The graph attentional propagation layer from the “Attention-based Graph Neural Network for Semi-Supervised Learning” paper
The topology adaptive graph convolutional networks operator from the “Topology Adaptive Graph Convolutional Networks” paper
The graph isomorphism operator from the “How Powerful are Graph Neural Networks?” paper
The modified GINConv operator from the “Strategies for Pre-training Graph Neural Networks” paper
The ARMA graph convolutional operator from the “Graph Neural Networks with Convolutional ARMA Filters” paper
The simple graph convolutional operator from the “Simplifying Graph Convolutional Networks” paper
The simple spectral graph convolutional operator from the “Simple Spectral Graph Convolution” paper
The approximate personalized propagation of neural predictions layer from the “Predict then Propagate: Graph Neural Networks meet Personalized PageRank” paper
The graph neural network operator from the “Convolutional Networks on Graphs for Learning Molecular Fingerprints” paper
The relational graph convolutional operator from the “Modeling Relational Data with Graph Convolutional Networks” paper
The relational graph attentional operator from the “Relational Graph Attention Networks” paper.
The signed graph convolutional operator from the “Signed Graph Convolutional Network” paper
The dynamic neighborhood aggregation operator from the “Just Jump: Towards Dynamic Neighborhood Aggregation in Graph Neural Networks” paper
The PointNet set layer from the “PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation” and “PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space” papers
alias of PointNetConv
The gaussian mixture model convolutional operator from the “Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs” paper
The spline-based convolutional operator from the “SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels” paper
The continuous kernel-based convolutional operator from the “Neural Message Passing for Quantum Chemistry” paper.
alias of NNConv
The crystal graph convolutional operator from the “Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties” paper
The edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper
The dynamic edge convolutional operator from the “Dynamic Graph CNN for Learning on Point Clouds” paper (see torch_geometric.nn.conv.EdgeConv), where the graph is dynamically constructed using nearest neighbors in the feature space.
The convolutional operator on $\mathcal{X}$-transformed points from the “PointCNN: Convolution On X-Transformed Points” paper
The PPFNet operator from the “PPFNet: Global Context Aware Local Features for Robust 3D Point Matching” paper
The (translation-invariant) feature-steered convolutional operator from the “FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis” paper
The Point Transformer layer from the “Point Transformer” paper
The hypergraph convolutional operator from the “Hypergraph Convolution and Hypergraph Attention” paper
The local extremum graph neural network operator from the “ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations” paper, which finds the importance of nodes with respect to their neighbors using the difference operator:
The Principal Neighbourhood Aggregation graph convolution operator from the “Principal Neighbourhood Aggregation for Graph Nets” paper
The ClusterGCN graph convolutional operator from the “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks” paper
The GENeralized Graph Convolution (GENConv) from the “DeeperGCN: All You Need to Train Deeper GCNs” paper.
The graph convolutional operator with initial residual connections and identity mapping (GCNII) from the “Simple and Deep Graph Convolutional Networks” paper
The path integral based convolutional operator from the “Path Integral Based Convolution and Pooling for Graph Neural Networks” paper
The Weisfeiler Lehman operator from the “A Reduction of a Graph to a Canonical Form and an Algebra Arising During this Reduction” paper, which iteratively refines node colorings:
The Weisfeiler Lehman operator from the “Wasserstein Weisfeiler-Lehman Graph Kernels” paper.
The FiLM graph convolutional operator from the “GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation” paper
The self-supervised graph attentional operator from the “How to Find Your Friendly Neighborhood: Graph Attention Design with Self-Supervision” paper
The Frequency Adaptive Graph Convolution operator from the “Beyond Low-Frequency Information in Graph Convolutional Networks” paper
The Efficient Graph Convolution from the “Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions” paper.
The pathfinder discovery network convolutional operator from the “Pathfinder Discovery Networks for Neural Message Passing” paper
A general GNN layer adapted from the “Design Space for Graph Neural Networks” paper.
The Heterogeneous Graph Transformer (HGT) operator from the “Heterogeneous Graph Transformer” paper.
The heterogeneous edge-enhanced graph attentional operator from the “Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction” paper, which enhances GATConv by:
A generic wrapper for computing graph convolution on heterogeneous graphs.
The Heterogeneous Graph Attention Operator from the “Heterogeneous Graph Attention Network” paper.
The Light Graph Convolution (LGC) operator from the “LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation” paper
The PointGNN operator from the “Point-GNN: Graph Neural Network for 3D Object Detection in a Point Cloud” paper
The general, powerful, scalable (GPS) graph transformer layer from the “Recipe for a General, Powerful, Scalable Graph Transformer” paper.
An extension of the torch.nn.Sequential container for defining sequential GNN models. Since GNN operators take in multiple input arguments, torch_geometric.nn.Sequential expects both global input arguments and function header definitions of individual operators. If omitted, an intermediate module will operate on the output of its preceding module:
from torch.nn import Linear, ReLU
from torch_geometric.nn import Sequential, GCNConv
model = Sequential('x, edge_index', [
    (GCNConv(in_channels, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    Linear(64, out_channels),
])
Here, 'x, edge_index' defines the input arguments of the model, and 'x, edge_index -> x' defines the function header, i.e. the input arguments and return types, of GCNConv.
In particular, this also allows the creation of more sophisticated models, such as those utilizing JumpingKnowledge:
from torch.nn import Linear, ReLU, Dropout
from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge
from torch_geometric.nn import global_mean_pool
model = Sequential('x, edge_index, batch', [
    (Dropout(p=0.5), 'x -> x'),
    (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x1, edge_index -> x2'),
    ReLU(inplace=True),
    (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),
    (JumpingKnowledge("cat", 64, num_layers=2), 'xs -> x'),
    (global_mean_pool, 'x, batch -> x'),
    Linear(2 * 64, dataset.num_classes),
])
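The resulting model can then be called like any other torch.nn.Module. A minimal sketch, assuming data is a torch_geometric.data.Data (or Batch) object drawn from the dataset used above:

out = model(data.x, data.edge_index, data.batch)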
Applies a linear transformation to the incoming data, similar to torch.nn.Linear. It supports lazy initialization and customizable weight and bias initialization.
The default weight initialization of torch.nn.Linear. (default: None)
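A small sketch of the lazy initialization this layer supports; with in_channels=-1 the input dimension is inferred on the first forward pass:

import torch
from torch_geometric.nn import Linear

lin = Linear(in_channels=-1, out_channels=64)  # input size not known yet
out = lin(torch.randn(10, 37))                 # weight is materialized as [64, 37] here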
Applies separate linear transformations to the incoming data according to types, using a separate weight matrix for each type $k$. It supports lazy initialization and customizable weight and bias initialization.
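And a hedged sketch for this type-wise variant, where type_vec assigns one of num_types types to every input row:

import torch
from torch_geometric.nn import HeteroLinear

lin = HeteroLinear(in_channels=16, out_channels=32, num_types=3)
x = torch.randn(5, 16)
type_vec = torch.tensor([0, 1, 1, 2, 0])  # the type of each row of x
out = lin(x, type_vec)                    # out has shape [5, 32]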