CS224W - Colab 3 MessagePassing实现GraphSAGE

Implement the GraphSAGE layer directly

1.GraphSage

对于一个具有编码 h v l − 1 h_v^{l-1} hvl1的中心节点 v v v,进行下一步状态更新的规则为:
h v ( l ) = W l ⋅ h v ( l − 1 ) + W r ⋅ A G G ( { h u ( l − 1 ) , ∀ u ∈ N ( v ) } ) h_v^{(l)} = W_l\cdot h_v^{(l-1)} + W_r \cdot AGG(\{h_u^{(l-1)}, \forall u \in N(v) \}) hv(l)=Wlhv(l1)+WrAGG({hu(l1),uN(v)})
W l W_l Wl W r W_r Wr为可学习的权重, N ( v ) N(v) N(v) 代表 v v v的邻接节点。 A G G ( ⋅ ) AGG(·) AGG() 为消息聚合函数,当采用 mean aggregation时,有
A G G ( { h u ( l − 1 ) , ∀ u ∈ N ( v ) } ) = 1 ∣ N ( v ) ∣ ∑ u ∈ N ( v ) h u ( l − 1 ) AGG(\{h_u^{(l-1)}, \forall u \in N(v) \}) = \frac{1}{|N(v)|} \sum_{u\in N(v)} h_u^{(l-1)} AGG({hu(l1),uN(v)})=N(v)1uN(v)hu(l1)

2.Implement

(1)实现方法

实现分三步,分别为

1)每一个邻居 u u u节点传递当前状态 u l − 1 u^{l-1} ul1

2)中心节点 v v v 使用聚合函数聚合收到的消息,在GraphSage中为简单求平均;

3)中心节点使用聚合消息更新自己的状态,在GraphSage中为残差。

(2)实现步骤

pytorch提供了MessagePassing父类,我们借此可以简洁实现消息传递。

class GraphSage(MessagePassing):
    
    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):  
        super(GraphSage, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        self.lin_l=nn.Linear(in_features=in_channels, out_features=out_channels)
        self.lin_r=nn.Linear(in_features=in_channels, out_features=out_channels)
    
    def message(self, x_j):
        out = None
        out = self.lin_r(x_j)
        return out
   
    def aggregate(self, inputs, index, dim_size = None):
        out = None
        node_dim = self.node_dim
        out=torch_scatter.scatter(inputs, index, dim=node_dim,reduce='mean')
        return out
 
    def forward(self, x, edge_index, size = None):
        out=self.propagate(edge_index,x=(x,x))
        out=self.lin_l(x)+out
        if self.normalize:
            out=F.normalize(out)
        return out

message函数定义全局消息传递的内容。参数x_j描述所有消息传递关系中源节点的特征,形状为 [ ∣ E ∣ , d ] [|\mathcal{E}|, d] [E,d] ( i , j ) ∈ E (i, j) \in \mathcal{E} (i,j)E.

aggregate函数定义了中心节点接收和聚合消息的方法。参数inputsmessage函数的返回值,index描述了每个中心节点 v v v接收来自邻居节点 u u u的消息在inputs的哪一行行。scatter函数声明为

torch_scatter.scatter(input: Tensor, index: Tensor, dim: int = -1, out: Optional[Tensor] = None, dim_size: Optional[int] = None, reduce: str = 'sum')→ Tensor[source]

函数功能为用indexdim指定的维度索引张量input,再根据reduce规则计算返回值。

CS224W - Colab 3 MessagePassing实现GraphSAGE_第1张图片

如图所示,中心节点0的邻居节点在input的第0、1、3个索引。

propagate函数定义在MessagePassing父类。用于启动一次消息传递过程。edge_index为整张图的边索引信息,形状是 [ 2 , E ] [2,\mathcal{E}] [2,E]。参数x存放邻居节点和中心节点的特征。因为每个节点既是中心节点又是邻居节点,且采用一样的特征描述,所以元组的两个元素是一样的。propagate函数会自动调用messageaggregate完成消息传递和消息聚合。

④当GraphSage对象被调用时,默认调用forward来启动消息传递。forward函数返回更新后的节点特征张量,形状为 [ ∣ N ∣ , d ] [|N|, d] [N,d]. N N N是所有节点的集合。

3. Train and Test

使用CORA dataset数据集进行节点分类任务。训练过程如下

CS224W - Colab 3 MessagePassing实现GraphSAGE_第2张图片

你可能感兴趣的:(python,深度学习,pytorch,人工智能)