利用MessagePassing实现GraphSAGE(了解pyG的底层逻辑)

一些基础的理解参考(必看前面的必要了解

pyG利用MessagePassing实现GCN(了解pyG的底层逻辑)_山、、、的博客-CSDN博客pyG利用MessagePassing实现GCNhttps://blog.csdn.net/qq_44689178/article/details/123736686

注: 如果看明白上面GCN的实现,下面这个可以不看,很简单。

下面解析如何按照GraphSAGE公式,利用MessagePassing实现SAGEConv层。

下面是官方文档:

torch_geometric.nn — pytorch_geometric 2.0.5 documentationicon-default.png?t=M276https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.message_passing.MessagePassing.propagate

首先,参考下面链接里面graphSAGE的公式

图神经网络(7)- GNN通用框架_山、、、的博客-CSDN博客介绍GNN通用框架,从GCN到GraphSAGE,GAThttps://blog.csdn.net/qq_44689178/article/details/123334408

具体来讲,我们使用下面这个公式

 此处AGG,为了简化我们选择使用求平均。

 那么在代码计算节点message和aggregate时就是使用下面这个公式:

 

第一个函数 __init__,定义了一些参数。其中lin_l和lin_r分别对应上面公式中的W1,W2

class GraphSage(MessagePassing):
    
    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):  
        super(GraphSage, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        self.lin_l = torch.nn.Linear(in_channels,out_channels) #W_l,对中心节点应用
        self.lin_r = torch.nn.Linear(in_channels,out_channels) #W_r,对邻居节点应用

        
        
        self.reset_parameters()

 下面的 reset_parameters(self):函数略过。主要解释forward函数。在forward函数里面,首先它调用了propagate函数,这个函数就会去调用 message函数和aggregate函数,按照我的理解,先调用message函数,message函数的输入会被当作aggregate函数的输入!

此处的propagate函数完成了,也就是对邻居节点特征求和的操作。

之后, x = self.lin_l(x)  完成了

 out = self.lin_r(out)就是完成了

 然后加起来,就完成了SAGE网络的一次卷积计算

 代码后面还对结果,做了标准化。

def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x, edge_index, size = None):
        """"""

        # message-passing + post-processing
        out=self.propagate(edge_index,x=(x,x),size=size)  #message passing
        # propagate方法前期准备所用到的参数,
        #后期依次调用self.message、self.aggregate和self.update方法。
       
        
        x = self.lin_l(x)   #  前一层的本身,做一次WH(线性变换)
        out = self.lin_r(out) # 对求完平均的邻居节点,做一次线性变换
        out = out + x
        if self.normalize:  #L2
            out=F.normalize(out)
            
        return out

后面的操作非常简单了,后面message和aggregate操作和实现GCN一模一样!不多说

    def message(self, x_j):
        # x_j 表示邻居节点

        out = x_j
        # 参考 GCN 例子 没什么message操作
        return out

    def aggregate(self, inputs, index, dim_size = None):

        out = None

        # The axis along which to index number of nodes.
        node_dim = self.node_dim
       
        #这个地方要实现,把邻居节点信息求平均的操作
        out = scatter(inputs, index, dim=node_dim, dim_size=dim_size, reduce='mean')
    
        
        return out

你可能感兴趣的:(GNN,神经网络,pytorch)