PyG利用MessagePassing搭建GCN实现节点分类

目录

  • 前言
  • 1. 数据处理
  • 2. MessagePassing
  • 3. GCN
    • 3.1 message
    • 3.2 aggregate
    • 3.3 update
    • 3.4 propagate
  • 4. GCN模型搭建
    • 4.1 模型训练
    • 4.2 模型测试
  • 5. 完整代码

前言

PyG和DGL是GNN领域的两大框架,两大框架的底层都是基于消息传递机制,即PyG中的MessagePassing基类和DGL中的Message Passing Paradigm。

关于DGL的消息传递范式,前面已经有几篇文章进行过讲解:

  1. 简单了解DGL中的数据格式
  2. 详解DGL中的消息传递API
  3. 利用DGL中的消息传递API手搭GCN实现节点分类

1. 数据处理

本篇文章使用Citeseer网络。Citeseer网络是一个引文网络,节点为论文,一共3327篇论文。论文一共分为六类:Agents、AI(人工智能)、DB(数据库)、IR(信息检索)、ML(机器语言)和HCI。如果两篇论文间存在引用关系,那么它们之间就存在链接关系。

dataset = Planetoid(root='data', name='CiteSeer')
dataset = dataset[0]
dataset.edge_index, _ = add_self_loops(dataset.edge_index)
dataset = dataset.to(device)
num_in_feats, num_out_feats = dataset.num_node_features, torch.max(dataset.y).item() + 1

2. MessagePassing

MessagePassing是PyG中定义的一个有关消息传递机制的基类,它通过自动处理消息传播来帮助创建此类消息传递图神经网络。用户只需定义消息函数message、聚合函数aggregate以及更新函数update,就能实现自定义GNN,这点和DGL类似。

消息传递的基本原理:
在这里插入图片描述
其中 x i ( k ) x_i^{(k)} xi(k)表示节点 i i i经过第 k k k层更新后的特征, e j , i e_{j,i} ej,i表示从节点 j j j到节点 i i i之间边的特征。

ϕ ( k ) \phi^{(k)} ϕ(k)表示第 k k k层的消息函数:例如可以将边特征和两个节点的特征求平均以得到新的特征。

□ \square 表示聚合函数:例如可以将节点 i i i的所有邻居节点经过消息函数处理后的特征进行求和,或者求和后再加上节点 i i i本身的特征 x i ( k − 1 ) x_i^{(k-1)} xi(k1)

γ ( k ) \gamma^(k) γ(k)表示第 k k k层的更新函数:例如可以将聚合后的特征经过简单的线性变换或者激活函数。

3. GCN

GCN的具体数学原理为:
在这里插入图片描述
对应到上面所讲的MessagePassing:

  1. message:将节点 i i i的所有邻居节点 j j j的特征进行一个简单的线性变换,然后将这些特征乘上一个权重 d e g ( i ) − 1 2 ⋅ d e g ( j ) − 1 2 deg(i)^{-\frac{1}{2}} \cdot deg(j)^{-\frac{1}{2}} deg(i)21deg(j)21,其中 d e g ( i ) deg(i) deg(i)表示节点 i i i添加自环后的度。
  2. aggregate:将节点 i i i所有邻居节点的经过消息函数处理后的特征求和。
  3. update:GCN中没有对应的update函数,不过上式中的update可以理解为加上一个bias。

3.1 message

message:将节点 i i i的所有邻居节点进行一个简单的线性变换,然后将这些特征乘上一个权重 d e g ( i ) − 1 2 ⋅ d e g ( j ) − 1 2 deg(i)^{-\frac{1}{2}} \cdot deg(j)^{-\frac{1}{2}} deg(i)21deg(j)21,其中 d e g ( i ) deg(i) deg(i)表示节点 i i i添加自环后的度。

因此,具体的代码实现为:

def message(self, x, edge_index):
    x = self.linear(x)
    row, col = edge_index
    deg = degree(col, x.size(0), dtype=x.dtype)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

    x_j = x[col]  # target nodes
    x_j = norm.view(-1, 1) * x_j  # 12431条边上target nodes的feature * norm

    return x_j

首先,我们将节点特征 x x x经过一个线性变换:

x = self.linear(x)

然后,得到所有节点的度的-0.5次方并计算乘积:

row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

其中,得到所有节点度的操作为:

row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)

其中col就是图中所有边中的目标节点,方法逻辑:统计col中从0到x.size(0)-1(节点数)中每个数出现的次数,该次数就是节点的度。注意,传入row时计算的实际上是出度,而传入col时计算的是入度,对于本文中使用的无向图来讲,二者计算结果一致。

得到所有节点的度后,计算每条边上源节点和目标节点度的-0.5次方的乘积以得到权重:

deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

最后,我们需要将邻居节点(目标节点)的特征乘上该权重并返回:

x_j = x[col]  # target nodes
x_j = norm.view(-1, 1) * x_j  # 12431条边上target nodes的feature * norm

return x_j

3.2 aggregate

aggregate:将节点 i i i所有邻居节点 x j x_j xj的经过消息函数处理后的特征求和。

def aggregate(self, x_j, edge_index):
    # x_j为target nodes的归一化特征
    row, col = edge_index
    # row(12431), x_j(12431, out_channels)
    out = scatter(x_j, row, dim=0, reduce='sum')
    return out

这里使用了torch_scatter中的scatter方法来对所有邻居节点的特征进行聚合。具体来讲,就是根据row中相同索引对应的 x j x_j xj中的元素进行求和处理,然后按照索引进行排序,其中 x j x_j xj为前面消息函数求得的所有边中目标节点的加权特征。

比如一共12431条边,那么row就是12431条边中源节点的索引值,假设索引0一共出现在了5个位置(节点0的出度为5),那么最终得到的out的第一个元素就是将 x j x_j xj中这5个位置的特征求和,也就是节点0的5个邻居节点的特征求和。

这样,经过scatter方法处理后,我们就得到了所有节点的更新后的特征值。

关于torch_scatter.scatter()的具体使用方法可以参考:torch_scatter.scatter()的使用方法详解。

3.3 update

观察PyG中对GCN的定义:
在这里插入图片描述
因此,我们可以将update简单理解为加上一个bias,即:

def update(self, out):
    return out + self.bias

3.4 propagate

在PyG中,MessagePassing通过调用propagate方法来实现图上的一次卷积操作,即前面提到的message、aggregate以及update操作:

def propagate(self, x, edge_index):
    out = self.message(x, edge_index)
    out = self.aggregate(out, edge_index)
    out = self.update(out)

    return out

因此,一个完整的GCNConv搭建如下:

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')
        self.linear = nn.Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.Tensor(out_channels))

    def message(self, x, edge_index):
        x = self.linear(x)
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        x_j = x[col]  # target nodes
        x_j = norm.view(-1, 1) * x_j   # 12431条边上target nodes的feature * norm

        return x_j

    def aggregate(self, x_j, edge_index):
        # x_j为target nodes的归一化特征
        row, col = edge_index
        # row(12431), x_j(12431, out_channels)
        out = scatter(x_j, row, dim=0, reduce='sum')
        return out

    def update(self, out):
        return out + self.bias

    def propagate(self, x, edge_index):
        out = self.message(x, edge_index)
        out = self.aggregate(out, edge_index)
        out = self.update(out)

        return out

    def forward(self, x, edge_index):
        return self.propagate(x, edge_index)

4. GCN模型搭建

一个简单的两层GCN搭建如下:

class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 32)
        self.conv2 = GCNConv(32, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = F.softmax(x, dim=1)

        return x

4.1 模型训练

训练:

def train(model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    model.train()
    min_epochs = 10
    best_model = None
    min_val_loss = 5
    for epoch in range(200):
        out = model(dataset)
        optimizer.zero_grad()
        loss = loss_function(out[dataset.train_mask], dataset.y[dataset.train_mask])
        loss.backward()
        optimizer.step()
        # validation
        val_loss = get_val_loss(model)
        if epoch + 1 >= min_epochs and val_loss < min_val_loss:
            min_val_loss = val_loss
            best_model = copy.deepcopy(model)
        print('Epoch: {:3d} train_Loss: {:.5f} val_loss: {:.5f}'.format(epoch, loss.item(), val_loss))
        model.train()

    return best_model

4.2 模型测试

测试:

def test(model):
    model.eval()
    _, pred = model(dataset).max(dim=1)
    correct = int(pred[dataset.test_mask].eq(dataset.y[dataset.test_mask]).sum().item())
    acc = correct / int(dataset.test_mask.sum())
    print('GCN Accuracy: {:.4f}'.format(acc))

实验结果:69.8%的准确率。

5. 完整代码

代码地址:GNNs-for-Node-Classification。原创不易,下载时请给个follow和star!感谢!!

你可能感兴趣的:(PyG,GNN,PyG,MessagePassing,节点分类)