Transductive link prediction on a homogeneous graph with PyG (PyTorch Geometric)

诸神缄默不语 - index of my CSDN blog posts

The code in this post is adapted from the official PyG example: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/link_pred.py

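Contents

  • 1. Getting the data
  • 2. Data preprocessing
  • 3. Building the link prediction model
  • 4. Instantiating the model, optimizer and loss function
  • 5. Writing the training function
  • 6. Writing the per-epoch test function
  • 7. Training and testing
  • 8. Complete code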

1. Getting the data

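This post uses the Cora dataset through PyG's built-in Planetoid class. If your environment can reach the internet directly, you can run the code below as-is. Otherwise, you can download the data manually by following my earlier post: 3 ways to fix PyG's Planetoid failing to download Cora and other datasets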

2. Data preprocessing

All preprocessing happens at load time, by passing a chain of PyG transforms to the dataset:

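  1. Row-normalize the node features (T.NormalizeFeatures(), docs: https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.NormalizeFeatures, source: torch_geometric.transforms.normalize_features): the global minimum of the feature matrix is subtracted from every entry, then each row is divided by its sum, with the sum clamped to at least 1 to avoid division by zero. Every row whose sum is at least 1 therefore sums to 1 afterwards. For non-negative features such as Cora's 0/1 bag-of-words vectors the minimum is 0, so zero entries stay zero: the transform rescales rows but does not make them sparser.
  2. Move each Data object to the target device, the GPU if one is available (T.ToDevice(device), docs: https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.ToDevice)
  3. Split the edges for link prediction:
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, add_negative_train_samples=False)
    Docs: https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.RandomLinkSplit
    5% of the edges go to validation and 10% to test. The training split contains neither validation nor test edges, and the validation split contains no test edges. Because this setup is transductive, the three resulting Data objects all contain the full set of nodes and node features; they differ only in their edges. In each split, edge_index holds the edges used for message passing, while edge_label_index and edge_label hold the node pairs to be scored and their labels. With add_negative_train_samples=False no negative edges are added to the training split (train() samples fresh ones every epoch), whereas the validation and test splits come with sampled negative edges, so their edge_label contains both 1s and 0s.
    Because RandomLinkSplit is passed as a transform, indexing the dataset (dataset[0]) returns a tuple of three Data objects: train_data, val_data and test_data.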
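To see what item 1 does to a feature matrix, the normalization can be reproduced in a few lines of plain PyTorch. This is a minimal sketch written from the description above, not PyG's source code; the function name normalize_features and the toy matrix are assumptions for illustration.

```python
import torch

def normalize_features(x: torch.Tensor) -> torch.Tensor:
    # Shift by the global minimum so that every entry is non-negative.
    x = x - x.min()
    # Divide each row by its sum; clamping the sum to at least 1 keeps
    # all-zero rows from being divided by zero (they stay all-zero).
    return x / x.sum(dim=-1, keepdim=True).clamp(min=1.0)

# Toy 0/1 bag-of-words features in the spirit of Cora's (made-up values).
x = torch.tensor([[1., 0., 1., 1.],
                  [0., 1., 0., 0.],
                  [0., 0., 0., 0.]])
x_norm = normalize_features(x)
print(x_norm)
print(x_norm.sum(dim=-1))  # first two rows sum to 1, the all-zero row stays 0
```

Since the toy features, like Cora's, are non-negative with a minimum of 0, the zero entries stay zero: normalization rescales the rows without changing their sparsity.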
import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                      add_negative_train_samples=False),
])
dataset = Planetoid('pyg_data/Planetoid', name='Cora', transform=transform)
print(type(dataset))
train_data, val_data, test_data = dataset[0]
print(type(train_data))

Output (warnings caused by the SciPy version are omitted):



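The nesting of the three splits described in item 3 can be illustrated with a simplified, plain-PyTorch sketch of an undirected, transductive edge split. It is not RandomLinkSplit's implementation: the helper split_edges, its random-permutation strategy and its return format are hypothetical, and unlike RandomLinkSplit it does not add negative edges to the validation and test splits. Each split is returned as a pair (message-passing edges, supervision edges), playing the roles of edge_index and edge_label_index.

```python
import torch

def split_edges(edge_index, num_val, num_test, seed=0):
    """Hypothetical sketch of an undirected, transductive edge split."""
    # Keep one direction per undirected edge (row < col).
    und = edge_index[:, edge_index[0] < edge_index[1]]
    gen = torch.Generator().manual_seed(seed)
    und = und[:, torch.randperm(und.size(1), generator=gen)]
    n_val = int(num_val * und.size(1))
    n_test = int(num_test * und.size(1))
    val_e = und[:, :n_val]
    test_e = und[:, n_val:n_val + n_test]
    train_e = und[:, n_val + n_test:]

    def both_dirs(e):
        # Store each undirected edge in both directions for message passing.
        return torch.cat([e, e.flip(0)], dim=1)

    return {
        # Training: pass messages over training edges, supervise on them too.
        'train': (both_dirs(train_e), train_e),
        # Validation: messages over training edges only, supervise on val edges.
        'val': (both_dirs(train_e), val_e),
        # Test: messages over training + validation edges, supervise on test edges.
        'test': (both_dirs(torch.cat([train_e, val_e], dim=1)), test_e),
    }

# Toy graph: an undirected ring on 20 nodes, stored in both directions.
n = 20
src = torch.arange(n)
ring = torch.stack([src, (src + 1) % n])
edge_index = torch.cat([ring, ring.flip(0)], dim=1)

splits = split_edges(edge_index, num_val=0.1, num_test=0.2)
for name, (mp_edges, sup_edges) in splits.items():
    print(name, mp_edges.size(1), sup_edges.size(1))
```

On this 20-node ring the sketch yields 14 training, 2 validation and 4 test edges; the test split passes messages over 32 directed edges (training plus validation, both directions), while the validation split sees only the 28 directed training edges. All three splits share the same 20 nodes, which is what makes the setup transductive.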
3. Building the link prediction model

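  1. encode(): computes node representations with a 2-layer GCN and a ReLU activation between the layers, with no other tricks.
  2. decode(): used in training and in test(); it scores only the node pairs in edge_label_index. The dot product of the two endpoint embeddings is computed as an element-wise product followed by a sum over the feature dimension. The result is a raw score (logit), not a probability.
  3. decode_all(): used once at the end, on the test graph, to score every node pair at once. Here the dot products are computed with a single matrix multiplication (z @ z.t()). A pair is predicted to be an edge when its score is greater than 0, which corresponds to a sigmoid probability above 0.5. The function returns the predicted edges as an edge list of shape [2, num_predicted_edges].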
import torch
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()
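The relationship between decode() and decode_all() can be checked on toy embeddings. In this sketch, random tensors stand in for real GCN outputs (an assumption for illustration): the scores from decode() equal the corresponding entries of z @ z.t(), and since sigmoid(0) = 0.5, the threshold of 0 in decode_all() is the same rule as predicting an edge when the probability exceeds 0.5.

```python
import torch

torch.manual_seed(0)
z = torch.randn(5, 8)  # toy embeddings: 5 nodes, 8 dimensions
edge_label_index = torch.tensor([[0, 1, 3],
                                 [2, 4, 4]])

# decode(): element-wise product of the endpoint embeddings, summed over features
scores = (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

# decode_all(): z @ z.t() contains the dot product of every node pair
full = z @ z.t()
print(torch.allclose(scores, full[edge_label_index[0], edge_label_index[1]], atol=1e-5))

# sigmoid(0) = 0.5, so "score > 0" is the same rule as "probability > 0.5"
print(torch.sigmoid(torch.tensor(0.0)))

# Thresholding the full matrix also keeps the (i, i) pairs, because
# z_i . z_i = ||z_i||^2 > 0 for any nonzero embedding
pred = (full > 0).nonzero(as_tuple=False).t()
print(pred.size())
```

The diagonal check is why decode_all() also returns every (i, i) pair, and because z @ z.t() is symmetric, each undirected pair appears in both directions (up to floating-point rounding).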

4. Instantiating the model, optimizer and loss function

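Link prediction is usually framed as binary classification (does an edge exist or not), so torch.nn.BCEWithLogitsLoss() is a common choice of loss function.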

model = Net(dataset.num_features, 128, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()
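BCEWithLogitsLoss takes raw scores (logits) and applies the sigmoid internally, which is why decode() returns unnormalized dot products rather than probabilities. The made-up tensors below only illustrate that the loss equals binary cross-entropy computed on sigmoid(score).

```python
import torch

logits = torch.tensor([2.0, -1.0, 0.5, -3.0])  # raw scores, e.g. from decode()
labels = torch.tensor([1.0, 0.0, 0.0, 1.0])    # 1 = positive edge, 0 = negative edge

loss = torch.nn.BCEWithLogitsLoss()(logits, labels)

# The same quantity written out: mean of -[y*log(p) + (1-y)*log(1-p)], p = sigmoid(s)
p = logits.sigmoid()
manual = -(labels * p.log() + (1 - labels) * (1 - p).log()).mean()

# And via BCELoss on explicit probabilities
via_bce = torch.nn.BCELoss()(p, labels)
print(loss.item(), manual.item(), via_bce.item())
```

The PyTorch documentation describes the combined BCEWithLogitsLoss as more numerically stable than a separate Sigmoid followed by BCELoss, which is another reason to keep decode() returning logits.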

5. Writing the training function

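The training function is called once per epoch.
It first computes node representations on the training graph with the GNN, then calls negative_sampling (docs: https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.negative_sampling) to draw as many negative edges as there are positive supervision edges, and finally computes the loss over the positive and negative edges together.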

from torch_geometric.utils import negative_sampling

def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)

    # We perform a new round of negative sampling for every training epoch:
    neg_edge_index = negative_sampling(
        edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
        num_neg_samples=train_data.edge_label_index.size(1), method='sparse')

    edge_label_index = torch.cat(
        [train_data.edge_label_index, neg_edge_index],
        dim=-1,
    )
    edge_label = torch.cat([
        train_data.edge_label,
        train_data.edge_label.new_zeros(neg_edge_index.size(1))
    ], dim=0)

    out = model.decode(z, edge_label_index).view(-1)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
    return loss
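The idea behind negative_sampling can be sketched with simple rejection sampling: draw random node pairs and keep those that are not edges of the training graph. The helper sample_negative_edges below is hypothetical and much slower than PyG's implementation; rejecting self-loops is a choice made in this sketch.

```python
import torch

def sample_negative_edges(edge_index, num_nodes, num_samples, seed=0):
    """Hypothetical rejection-sampling sketch (not PyG's negative_sampling)."""
    gen = torch.Generator().manual_seed(seed)
    # Encode each directed pair (u, v) as the integer u * num_nodes + v.
    seen = set((edge_index[0] * num_nodes + edge_index[1]).tolist())
    chosen = []
    attempts = 0
    while len(chosen) < num_samples and attempts < 100 * num_samples:
        attempts += 1
        u, v = torch.randint(num_nodes, (2,), generator=gen).tolist()
        key = u * num_nodes + v
        # Reject self-loops, existing edges and pairs that were already drawn.
        if u == v or key in seen:
            continue
        seen.add(key)
        chosen.append((u, v))
    # Shape [2, num_drawn], the same layout as edge_index.
    return torch.tensor(chosen, dtype=torch.long).reshape(-1, 2).t()

# Toy undirected graph on 6 nodes (both directions stored, as in PyG).
edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                           [1, 0, 2, 1, 4, 3]])
neg = sample_negative_edges(edge_index, num_nodes=6, num_samples=6)
print(neg)
```

Because train() samples against train_data.edge_index, which excludes the held-out edges, a sampled negative can occasionally be a validation or test edge; in a sparse graph such as Cora this is rare. Resampling every epoch also means the model sees a different set of negatives each time.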

6. Writing the per-epoch test function

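Personally I prefer the with torch.no_grad(): form; the code keeps the official example's @torch.no_grad() decorator, which has the same effect.
The test function is called on the validation and test splits in every epoch.
It scores the labeled node pairs in data.edge_label_index (the split's positive edges plus its sampled negative edges), passes the scores through a sigmoid to turn them into edge probabilities, and computes the ROC AUC against data.edge_label.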

from sklearn.metrics import roc_auc_score

@torch.no_grad()
def test(data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
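ROC AUC can be read as the probability that a randomly chosen positive pair receives a higher score than a randomly chosen negative pair, with ties counting one half. The hypothetical helper roc_auc below computes exactly that by comparing every positive score with every negative score. It also shows why the sigmoid in test() does not change the AUC: only the ordering of the scores matters, and sigmoid is strictly increasing.

```python
import torch

def roc_auc(labels: torch.Tensor, scores: torch.Tensor) -> float:
    """Hypothetical helper: P(positive score > negative score), ties count 1/2."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    # Compare every positive with every negative (memory grows as P * N).
    diff = pos.unsqueeze(1) - neg.unsqueeze(0)
    wins = (diff > 0).float() + 0.5 * (diff == 0).float()
    return wins.mean().item()

labels = torch.tensor([1, 1, 0, 0, 1, 0])
scores = torch.tensor([0.9, 0.4, 0.35, 0.8, 0.7, 0.1])
print(roc_auc(labels, scores))  # 7 of the 9 positive/negative pairs are ordered correctly

logits = torch.tensor([2.0, -0.5, -1.0, 1.5, 0.3, -2.0])
# sigmoid is strictly increasing, so it does not change the AUC
print(roc_auc(labels, logits) == roc_auc(labels, logits.sigmoid()))
```

The pairwise comparison needs memory proportional to the number of positives times the number of negatives, so it only suits small examples; sklearn's roc_auc_score remains the practical choice.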

7. Training and testing

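Train for 100 epochs. The reported Final Test value is the test AUC at the epoch with the best validation AUC. After training, the model scores every node pair on the test graph to obtain all edges it predicts to exist.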

best_val_auc = final_test_auc = 0
for epoch in range(1, 101):
    loss = train()
    val_auc = test(val_data)
    test_auc = test(test_data)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')

print(f'Final Test: {final_test_auc:.4f}')

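# Score every node pair of the test graph; no gradients are needed here
with torch.no_grad():
    z = model.encode(test_data.x, test_data.edge_index)
    final_edge_index = model.decode_all(z)
print(final_edge_index.size())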

Output:

Epoch: 001, Loss: 0.6930, Val: 0.6729, Test: 0.7026
Epoch: 002, Loss: 0.6820, Val: 0.6589, Test: 0.6913
Epoch: 003, Loss: 0.7065, Val: 0.6619, Test: 0.6967
Epoch: 004, Loss: 0.6766, Val: 0.6686, Test: 0.7069
Epoch: 005, Loss: 0.6842, Val: 0.6716, Test: 0.7128
Epoch: 006, Loss: 0.6876, Val: 0.6637, Test: 0.7132
Epoch: 007, Loss: 0.6881, Val: 0.6471, Test: 0.7009
Epoch: 008, Loss: 0.6867, Val: 0.6317, Test: 0.6859
Epoch: 009, Loss: 0.6829, Val: 0.6240, Test: 0.6767
Epoch: 010, Loss: 0.6765, Val: 0.6223, Test: 0.6720
Epoch: 011, Loss: 0.6715, Val: 0.6208, Test: 0.6684
Epoch: 012, Loss: 0.6759, Val: 0.6204, Test: 0.6640
Epoch: 013, Loss: 0.6687, Val: 0.6272, Test: 0.6656
Epoch: 014, Loss: 0.6621, Val: 0.6488, Test: 0.6778
Epoch: 015, Loss: 0.6593, Val: 0.6748, Test: 0.6907
Epoch: 016, Loss: 0.6534, Val: 0.6824, Test: 0.6923
Epoch: 017, Loss: 0.6477, Val: 0.6796, Test: 0.6867
Epoch: 018, Loss: 0.6389, Val: 0.6847, Test: 0.6888
Epoch: 019, Loss: 0.6332, Val: 0.7155, Test: 0.7115
Epoch: 020, Loss: 0.6217, Val: 0.7487, Test: 0.7430
Epoch: 021, Loss: 0.6060, Val: 0.7645, Test: 0.7582
Epoch: 022, Loss: 0.5993, Val: 0.7650, Test: 0.7574
Epoch: 023, Loss: 0.5837, Val: 0.7632, Test: 0.7550
Epoch: 024, Loss: 0.5719, Val: 0.7612, Test: 0.7530
Epoch: 025, Loss: 0.5654, Val: 0.7565, Test: 0.7518
Epoch: 026, Loss: 0.5697, Val: 0.7574, Test: 0.7534
Epoch: 027, Loss: 0.5676, Val: 0.7610, Test: 0.7576
Epoch: 028, Loss: 0.5551, Val: 0.7629, Test: 0.7634
Epoch: 029, Loss: 0.5446, Val: 0.7682, Test: 0.7723
Epoch: 030, Loss: 0.5422, Val: 0.7774, Test: 0.7848
Epoch: 031, Loss: 0.5259, Val: 0.7896, Test: 0.7988
Epoch: 032, Loss: 0.5277, Val: 0.8005, Test: 0.8127
Epoch: 033, Loss: 0.5218, Val: 0.8135, Test: 0.8245
Epoch: 034, Loss: 0.5156, Val: 0.8234, Test: 0.8342
Epoch: 035, Loss: 0.5057, Val: 0.8285, Test: 0.8414
Epoch: 036, Loss: 0.4981, Val: 0.8314, Test: 0.8462
Epoch: 037, Loss: 0.4984, Val: 0.8302, Test: 0.8459
Epoch: 038, Loss: 0.4960, Val: 0.8332, Test: 0.8489
Epoch: 039, Loss: 0.4873, Val: 0.8381, Test: 0.8555
Epoch: 040, Loss: 0.4883, Val: 0.8418, Test: 0.8609
Epoch: 041, Loss: 0.4993, Val: 0.8427, Test: 0.8615
Epoch: 042, Loss: 0.4852, Val: 0.8452, Test: 0.8616
Epoch: 043, Loss: 0.4718, Val: 0.8474, Test: 0.8640
Epoch: 044, Loss: 0.4768, Val: 0.8492, Test: 0.8679
Epoch: 045, Loss: 0.4708, Val: 0.8472, Test: 0.8688
Epoch: 046, Loss: 0.4726, Val: 0.8457, Test: 0.8680
Epoch: 047, Loss: 0.4729, Val: 0.8500, Test: 0.8698
Epoch: 048, Loss: 0.4726, Val: 0.8517, Test: 0.8705
Epoch: 049, Loss: 0.4730, Val: 0.8527, Test: 0.8722
Epoch: 050, Loss: 0.4715, Val: 0.8521, Test: 0.8734
Epoch: 051, Loss: 0.4667, Val: 0.8547, Test: 0.8756
Epoch: 052, Loss: 0.4609, Val: 0.8577, Test: 0.8784
Epoch: 053, Loss: 0.4632, Val: 0.8607, Test: 0.8829
Epoch: 054, Loss: 0.4612, Val: 0.8626, Test: 0.8862
Epoch: 055, Loss: 0.4591, Val: 0.8646, Test: 0.8878
Epoch: 056, Loss: 0.4568, Val: 0.8644, Test: 0.8874
Epoch: 057, Loss: 0.4569, Val: 0.8656, Test: 0.8874
Epoch: 058, Loss: 0.4568, Val: 0.8688, Test: 0.8897
Epoch: 059, Loss: 0.4516, Val: 0.8721, Test: 0.8929
Epoch: 060, Loss: 0.4567, Val: 0.8729, Test: 0.8942
Epoch: 061, Loss: 0.4625, Val: 0.8742, Test: 0.8938
Epoch: 062, Loss: 0.4547, Val: 0.8729, Test: 0.8919
Epoch: 063, Loss: 0.4479, Val: 0.8723, Test: 0.8927
Epoch: 064, Loss: 0.4517, Val: 0.8728, Test: 0.8962
Epoch: 065, Loss: 0.4517, Val: 0.8719, Test: 0.8972
Epoch: 066, Loss: 0.4538, Val: 0.8726, Test: 0.8962
Epoch: 067, Loss: 0.4532, Val: 0.8718, Test: 0.8944
Epoch: 068, Loss: 0.4540, Val: 0.8725, Test: 0.8937
Epoch: 069, Loss: 0.4542, Val: 0.8734, Test: 0.8953
Epoch: 070, Loss: 0.4487, Val: 0.8726, Test: 0.8967
Epoch: 071, Loss: 0.4497, Val: 0.8727, Test: 0.8973
Epoch: 072, Loss: 0.4539, Val: 0.8694, Test: 0.8949
Epoch: 073, Loss: 0.4478, Val: 0.8703, Test: 0.8937
Epoch: 074, Loss: 0.4449, Val: 0.8737, Test: 0.8945
Epoch: 075, Loss: 0.4486, Val: 0.8770, Test: 0.8968
Epoch: 076, Loss: 0.4491, Val: 0.8724, Test: 0.8970
Epoch: 077, Loss: 0.4431, Val: 0.8678, Test: 0.8957
Epoch: 078, Loss: 0.4447, Val: 0.8688, Test: 0.8952
Epoch: 079, Loss: 0.4540, Val: 0.8704, Test: 0.8943
Epoch: 080, Loss: 0.4548, Val: 0.8741, Test: 0.8955
Epoch: 081, Loss: 0.4468, Val: 0.8746, Test: 0.8985
Epoch: 082, Loss: 0.4495, Val: 0.8727, Test: 0.8994
Epoch: 083, Loss: 0.4473, Val: 0.8708, Test: 0.8990
Epoch: 084, Loss: 0.4464, Val: 0.8715, Test: 0.8976
Epoch: 085, Loss: 0.4376, Val: 0.8755, Test: 0.8977
Epoch: 086, Loss: 0.4455, Val: 0.8762, Test: 0.8993
Epoch: 087, Loss: 0.4442, Val: 0.8727, Test: 0.9004
Epoch: 088, Loss: 0.4411, Val: 0.8726, Test: 0.9009
Epoch: 089, Loss: 0.4445, Val: 0.8760, Test: 0.9010
Epoch: 090, Loss: 0.4474, Val: 0.8780, Test: 0.9002
Epoch: 091, Loss: 0.4468, Val: 0.8754, Test: 0.9009
Epoch: 092, Loss: 0.4470, Val: 0.8712, Test: 0.9015
Epoch: 093, Loss: 0.4467, Val: 0.8680, Test: 0.9006
Epoch: 094, Loss: 0.4454, Val: 0.8720, Test: 0.9019
Epoch: 095, Loss: 0.4355, Val: 0.8761, Test: 0.9028
Epoch: 096, Loss: 0.4486, Val: 0.8749, Test: 0.9013
Epoch: 097, Loss: 0.4418, Val: 0.8695, Test: 0.8999
Epoch: 098, Loss: 0.4396, Val: 0.8651, Test: 0.9002
Epoch: 099, Loss: 0.4365, Val: 0.8684, Test: 0.9034
Epoch: 100, Loss: 0.4428, Val: 0.8720, Test: 0.9050
Final Test: 0.9002
torch.Size([2, 3262820])
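The printed torch.Size([2, 3262820]) counts ordered node pairs with a positive score. Because decode_all() thresholds the full z @ z.t() matrix, this count includes (i, i) self-pairs, lists each undirected pair in both directions, and includes pairs that are already edges of the graph. A hypothetical post-processing helper, new_undirected_links, that keeps each unordered pair once and drops already-known edges might look like this; it is an illustration, not part of PyG or the original example, and the toy embeddings are made up.

```python
import torch

def new_undirected_links(pred_edge_index, known_edge_index):
    """Hypothetical post-processing of decode_all() output."""
    # Keep each unordered pair once (row < col); this also drops (i, i) pairs.
    pred = pred_edge_index[:, pred_edge_index[0] < pred_edge_index[1]]
    # Remove pairs that are already known edges, in either direction.
    known = set(map(tuple, known_edge_index.t().tolist()))
    keep = [i for i, (u, v) in enumerate(pred.t().tolist())
            if (u, v) not in known and (v, u) not in known]
    return pred[:, torch.tensor(keep, dtype=torch.long)]

# Toy embeddings for 4 nodes and one known undirected edge (0, 1).
z = torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
pred = ((z @ z.t()) > 0).nonzero(as_tuple=False).t()  # same rule as decode_all()
known = torch.tensor([[0, 1], [1, 0]])
result = new_undirected_links(pred, known)
print(result)
```

In the toy example the only new candidate link is (1, 2). For the tutorial's output, an assumed call would be new_undirected_links(final_edge_index, test_data.edge_index); the Python-level filtering favors simplicity over speed.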

8. Complete code

import torch
from sklearn.metrics import roc_auc_score

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                      add_negative_train_samples=False),
])
dataset = Planetoid('pyg_data/Planetoid', name='Cora', transform=transform)
train_data, val_data, test_data = dataset[0]


class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()


model = Net(dataset.num_features, 128, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()


def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)

    # We perform a new round of negative sampling for every training epoch:
    neg_edge_index = negative_sampling(
        edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
        num_neg_samples=train_data.edge_label_index.size(1), method='sparse')

    edge_label_index = torch.cat(
        [train_data.edge_label_index, neg_edge_index],
        dim=-1,
    )
    edge_label = torch.cat([
        train_data.edge_label,
        train_data.edge_label.new_zeros(neg_edge_index.size(1))
    ], dim=0)

    out = model.decode(z, edge_label_index).view(-1)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
    return loss


@torch.no_grad()
def test(data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())


best_val_auc = final_test_auc = 0
for epoch in range(1, 101):
    loss = train()
    val_auc = test(val_data)
    test_auc = test(test_data)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')

print(f'Final Test: {final_test_auc:.4f}')

# Score every node pair of the test graph; no gradients are needed here
with torch.no_grad():
    z = model.encode(test_data.x, test_data.edge_index)
    final_edge_index = model.decode_all(z)

print(final_edge_index.size())
