对于一个具有编码 h v l − 1 h_v^{l-1} hvl−1的中心节点 v v v,进行下一步状态更新的规则为:
h v ( l ) = W l ⋅ h v ( l − 1 ) + W r ⋅ A G G ( { h u ( l − 1 ) , ∀ u ∈ N ( v ) } ) h_v^{(l)} = W_l\cdot h_v^{(l-1)} + W_r \cdot AGG(\{h_u^{(l-1)}, \forall u \in N(v) \}) hv(l)=Wl⋅hv(l−1)+Wr⋅AGG({hu(l−1),∀u∈N(v)})
W l W_l Wl 和 W r W_r Wr为可学习的权重, N ( v ) N(v) N(v) 代表 v v v的邻接节点。 A G G ( ⋅ ) AGG(·) AGG(⋅) 为消息聚合函数,当采用 mean aggregation时,有
A G G ( { h u ( l − 1 ) , ∀ u ∈ N ( v ) } ) = 1 ∣ N ( v ) ∣ ∑ u ∈ N ( v ) h u ( l − 1 ) AGG(\{h_u^{(l-1)}, \forall u \in N(v) \}) = \frac{1}{|N(v)|} \sum_{u\in N(v)} h_u^{(l-1)} AGG({hu(l−1),∀u∈N(v)})=∣N(v)∣1u∈N(v)∑hu(l−1)
实现分三步,分别为
1)每一个邻居 u u u节点传递当前状态 u l − 1 u^{l-1} ul−1;
2)中心节点 v v v 使用聚合函数聚合收到的消息,在GraphSage中为简单求平均;
3)中心节点使用聚合消息更新自己的状态,在GraphSage中为残差。
pytorch
提供了MessagePassing
父类,我们借此可以简洁实现消息传递。
class GraphSage(MessagePassing):
def __init__(self, in_channels, out_channels, normalize = True,
bias = False, **kwargs):
super(GraphSage, self).__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.normalize = normalize
self.lin_l=nn.Linear(in_features=in_channels, out_features=out_channels)
self.lin_r=nn.Linear(in_features=in_channels, out_features=out_channels)
def message(self, x_j):
out = None
out = self.lin_r(x_j)
return out
def aggregate(self, inputs, index, dim_size = None):
out = None
node_dim = self.node_dim
out=torch_scatter.scatter(inputs, index, dim=node_dim,reduce='mean')
return out
def forward(self, x, edge_index, size = None):
out=self.propagate(edge_index,x=(x,x))
out=self.lin_l(x)+out
if self.normalize:
out=F.normalize(out)
return out
①message
函数定义全局消息传递的内容。参数x_j
描述所有消息传递关系中源节点的特征,形状为 [ ∣ E ∣ , d ] [|\mathcal{E}|, d] [∣E∣,d], ( i , j ) ∈ E (i, j) \in \mathcal{E} (i,j)∈E.
②aggregate
函数定义了中心节点接收和聚合消息的方法。参数inputs
是message
函数的返回值,index
描述了每个中心节点 v v v接收来自邻居节点 u u u的消息在inputs
的哪一行行。scatter
函数声明为
torch_scatter.scatter(input: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, dim_size: Optional[int] = None, reduce: str = 'sum')→ Tensor[source]
函数功能为用index
在dim
指定的维度索引张量input
,再根据reduce
规则计算返回值。
如图所示,中心节点0的邻居节点在input
的第0、1、3个索引。
③propagate
函数定义在MessagePassing
父类。用于启动一次消息传递过程。edge_index
为整张图的边索引信息,形状是 [ 2 , E ] [2,\mathcal{E}] [2,E]。参数x
存放邻居节点和中心节点的特征。因为每个节点既是中心节点又是邻居节点,且采用一样的特征描述,所以元组的两个元素是一样的。propagate
函数会自动调用message
和aggregate
完成消息传递和消息聚合。
④当GraphSage
对象被调用时,默认调用forward
来启动消息传递。forward
函数返回更新后的节点特征张量,形状为 [ ∣ N ∣ , d ] [|N|, d] [∣N∣,d]. N N N是所有节点的集合。
使用CORA dataset数据集进行节点分类任务。训练过程如下