PyG搭建GNN实现链接回归预测

前言

前面写了一些有关GNN的各种图任务,主要是节点分类以及链接预测:

  1. PyG搭建GCN前的准备:了解PyG中的数据格式
  2. PyG搭建GCN实现节点分类(GCNConv参数详解)
  3. PyG搭建GAT实现节点分类
  4. PyG利用MessagePassing搭建GCN实现节点分类
  5. 搭建SGC实现引文网络节点预测(PyTorch+PyG)
  6. PyG搭建R-GCN实现节点分类
  7. PyG搭建异质图注意力网络HAN实现DBLP节点分类
  8. 链接预测中训练集、验证集以及测试集的划分(以PyG的RandomLinkSplit为例)
  9. PyG搭建GCN实现链接预测
  10. PyG搭建R-GCN实现链接预测

其中链接预测主要指预测某对节点间是否存在边,是一个二分类任务,即有(1)/没有(0)边。而链接回归,顾名思义,就是预测某对节点构成的边上的某一个具体数值

数据集

链接回归可以存在于一个图中,也可以多个图同时进行训练。本文的数据集为多个图,给定部分图参与训练,然后预测指定图上所有链接上的值。

单个图如下所示:

Data(x=[3, 75], edge_index=[2, 4], edge_values=[4])

任务是预测图上所有边上的值,即edge_value

图的尺寸各不相同,为了实现批量训练,使用PyG提供的DataLoader对多个图进行批量封装:

train_loader = DataLoader(datas[:num_train], batch_size=batch_size, shuffle=True, drop_last=False)
val_loader = DataLoader(datas[num_train:num_train+num_val], batch_size=batch_size, shuffle=True, drop_last=False)
test_loader = DataLoader(datas[num_train+num_val:], batch_size=batch_size, shuffle=True, drop_last=False)

模型

搭建一个GCN实现链接回归预测:

class GCN(torch.nn.Module):
    def __init__(self, in_feats, h_feats, out_feats):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_feats, h_feats)
        self.conv2 = GCNConv(h_feats, out_feats)
        self.fc = nn.Sequential(
            nn.Linear(2 * out_feats, out_feats),
            nn.ReLU(),
            nn.Linear(out_feats, 1)
        )

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = x.float()
        x = F.elu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        # 
        x_src = x[edge_index[0]]
        x_dst = x[edge_index[1]]
        edge_x = torch.cat((x_src, x_dst), dim=1)
        out = self.fc(edge_x)
        out = torch.flatten(out)

        return out

可以看出,我们首先利用GCN得到了图中所有节点的嵌入表示x,然后根据x取出图中所有链接两端的节点的表示向量:

x_src = x[edge_index[0]]
x_dst = x[edge_index[1]]

然后,为了预测链接上的值,我们采用了一种最简单的方式:

edge_x = torch.cat((x_src, x_dst), dim=1)
out = self.fc(edge_x)

即将链接两端节点的向量进行拼接,然后将拼接后的向量经过一个线性层以得到链接上的预测值。

最终得到的out包含了图中所有链接上的预测值。

训练/测试

模型训练:

def train(train_loader, val_loader, test_loader, m, n):
    model = GCN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
    loss_function = torch.nn.MSELoss().to(device)
    scheduler = StepLR(optimizer, step_size=50, gamma=0.5)
    min_epochs = 10
    min_val_loss = 5
    best_model = None
    best_mape = 0
    model.train()
    for epoch in tqdm(range(50)):
        train_losses = []
        for tr in train_loader:
            tr = tr.to(device)
            out = model(tr)
            optimizer.zero_grad()
            loss = loss_function(out, tr.edge_values.float())
            loss.backward()
            optimizer.step()
            scheduler.step()
            train_losses.append(loss.item())
        # validation
        val_loss, test_mape, test_loader = test(model, val_loader, test_loader, m, n)
        if val_loss < min_val_loss and epoch + 1 > min_epochs:
            min_val_loss = val_loss
            best_model = copy.deepcopy(model)
            best_mape = test_mape
        print('Epoch {:03d} train_loss {:.4f} val_loss {:.4f} test mape {:.4f}'.format(epoch,
                                                                                       np.mean(train_losses),
                                                                                       val_loss, test_mape))

    print('best mape:', best_mape)
    # 反归一化
    for te in test_loader:
        t = te.edge_values.cpu().numpy()
        t = (m - n) * t + n
        te.edge_values = torch.FloatTensor(t)
        
    return best_model

你可能感兴趣的:(PyG,GNN,图神经网络,链接回归,链接预测,GCN,PyG)