一些基础的理解参考(必看前面的必要了解)
pyG利用MessagePassing实现GCN(了解pyG的底层逻辑)_山、、、的博客-CSDN博客pyG利用MessagePassing实现GCNhttps://blog.csdn.net/qq_44689178/article/details/123736686
注: 如果看明白上面GCN的实现,下面这个可以不看,很简单。
下面解析如何按照GraphSAGE公式,利用MessagePassing实现SAGEConv层。
下面是官方文档:
torch_geometric.nn — pytorch_geometric 2.0.5 documentationhttps://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.message_passing.MessagePassing.propagate
首先,参考下面链接里面graphSAGE的公式
图神经网络(7)- GNN通用框架_山、、、的博客-CSDN博客介绍GNN通用框架,从GCN到GraphSAGE,GAThttps://blog.csdn.net/qq_44689178/article/details/123334408
具体来讲,我们使用下面这个公式
此处AGG,为了简化我们选择使用求平均。
那么在代码计算节点message和aggregate时就是使用下面这个公式:
第一个函数 __init__,定义了一些参数。其中lin_l和lin_r分别对应上面公式中的W1,W2
class GraphSage(MessagePassing):
def __init__(self, in_channels, out_channels, normalize = True,
bias = False, **kwargs):
super(GraphSage, self).__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.normalize = normalize
self.lin_l = torch.nn.Linear(in_channels,out_channels) #W_l,对中心节点应用
self.lin_r = torch.nn.Linear(in_channels,out_channels) #W_r,对邻居节点应用
self.reset_parameters()
下面的 reset_parameters(self):函数略过。主要解释forward函数。在forward函数里面,首先它调用了propagate函数,这个函数就会去调用 message函数和aggregate函数,按照我的理解,先调用message函数,message函数的输入会被当作aggregate函数的输入!
此处的propagate函数完成了,也就是对邻居节点特征求和的操作。
然后加起来,就完成了SAGE网络的一次卷积计算
代码后面还对结果,做了标准化。
def reset_parameters(self):
self.lin_l.reset_parameters()
self.lin_r.reset_parameters()
def forward(self, x, edge_index, size = None):
""""""
# message-passing + post-processing
out=self.propagate(edge_index,x=(x,x),size=size) #message passing
# propagate方法前期准备所用到的参数,
#后期依次调用self.message、self.aggregate和self.update方法。
x = self.lin_l(x) # 前一层的本身,做一次WH(线性变换)
out = self.lin_r(out) # 对求完平均的邻居节点,做一次线性变换
out = out + x
if self.normalize: #L2
out=F.normalize(out)
return out
后面的操作非常简单了,后面message和aggregate操作和实现GCN一模一样!不多说
def message(self, x_j):
# x_j 表示邻居节点
out = x_j
# 参考 GCN 例子 没什么message操作
return out
def aggregate(self, inputs, index, dim_size = None):
out = None
# The axis along which to index number of nodes.
node_dim = self.node_dim
#这个地方要实现,把邻居节点信息求平均的操作
out = scatter(inputs, index, dim=node_dim, dim_size=dim_size, reduce='mean')
return out