Torch geometric GatedGraphConv 源码分析

Torch geometric GatedGraphConv 源码分析

  • 相关论文
  • 公式
  • GatedGraphConv源码

相关论文

Gated Graph Sequence Neural Networks

公式

h i ( 0 ) = x i   ∥   0 (1) \mathbf{h}_i^{(0)}= \mathbf{x}_i \, \Vert \, \mathbf{0} \tag{1} hi(0)=xi0(1) m i ( l + 1 ) = ∑ j ∈ N ( i ) Θ ⋅ h j ( l ) (2) \\\mathbf{m}_i^{(l+1)} = \sum_{j \in \mathcal{N}(i)} \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)} \tag{2} mi(l+1)=jN(i)Θhj(l)(2) h i ( l + 1 ) = GRU ( m i ( l + 1 ) , h i ( l ) ) (3) \\\mathbf{h}_i^{(l+1)} = \textrm{GRU} (\mathbf{m}_i^{(l+1)}, \mathbf{h}_i^{(l)}) \tag{3} hi(l+1)=GRU(mi(l+1),hi(l))(3)
其中,第(1)个公式中, h i ( 0 ) \mathbf{h}_i^{(0)} hi(0)是输入状态, x i \mathbf{x}_i xi是节点i的特征,后面加0是一种padding操作,即把特征扩充到指定维度上。第(2)个公式中, Θ \mathbf{\Theta} Θ是待学习的参数矩阵,这里是在汇聚周围节点的信息,第(3)个公式就是一个GRU单元,把上面两个公式得到的作为输入,得到一个输出,该输出可以作为节点i的新特征。

GatedGraphConv源码

class GatedGraphConv(out_channels, num_layers, aggr='add', bias=True, **kwargs)

  • out_channels (int) – Size of each input sample. 输出维度(输入维度不能大于输出维度)
  • num_layers (int) – The sequence length L. GRU单元的层数
  • aggr (string, optional) – The aggregation scheme to use (“add”, “mean”, “max”). (default: “add”)公式(2)中的聚合方式,默认是累加
  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)公式(2)中聚合后是否加一个偏置项,默认True
  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
import torch
from torch import Tensor
from torch.nn import Parameter as Param
from torch_geometric.nn.conv import MessagePassing

from ..inits import uniform


class GatedGraphConv(MessagePassing):


    def __init__(self,
                 out_channels,
                 num_layers,
                 aggr='add',
                 bias=True,
                 **kwargs):
        super(GatedGraphConv, self).__init__(aggr=aggr, **kwargs)

        self.out_channels = out_channels
        self.num_layers = num_layers

        self.weight = Param(Tensor(num_layers, out_channels, out_channels))
        self.rnn = torch.nn.GRUCell(out_channels, out_channels, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        uniform(self.out_channels, self.weight)
        self.rnn.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """"""
        h = x if x.dim() == 2 else x.unsqueeze(-1)
        if h.size(1) > self.out_channels:
            raise ValueError('The number of input channels is not allowed to '
                             'be larger than the number of output channels')

        if h.size(1) < self.out_channels:
            zero = h.new_zeros(h.size(0), self.out_channels - h.size(1))
            h = torch.cat([h, zero], dim=1)

        for i in range(self.num_layers):
            m = torch.matmul(h, self.weight[i])
            m = self.propagate(edge_index, x=m, edge_weight=edge_weight)
            h = self.rnn(m, h)

        return h

    def message(self, x_j, edge_weight):
        if edge_weight is not None:
            return edge_weight.view(-1, 1) * x_j
        return x_j

你可能感兴趣的:(图神经网络,GNN)