图神经网络pyg,使用Cora数据集

关于Cora数据集的手动导入问题,见博客https://blog.csdn.net/qq_42969578/article/details/121450994?spm=1001.2014.3001.5501

关于Core数据集的一些问题

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#导入数据集
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root = "data/", name = "Cora")
data = dataset[0]
print(data.keys)
print(data.y)
>>>['test_mask', 'train_mask', 'edge_index', 'y', 'val_mask', 'x']
tensor([3, 4, 4,  ..., 3, 3, 3])

这里:

  • train_mask: 训练集的mask向量,用于标识哪些结点属于训练集
  • test_mask: 测试集的mask向量,用于标识哪些结点属于测试集
  • val_mask: 验证集的mask向量,用于标识哪些结点数据验证集
  • x: 输入的特征矩阵
  • y: 结点标签
#implement a two layer GCN
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)
        
    def forward(self,data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training = self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim = 1)
# model = GCN()
# print(model.parameters())
#======================================
#开始训练
device = "cuda:9"
model = GCN().to(device)	#把模型传入GPU
data = dataset[0].to(device)	#把数据传入GPU
#优化器需要传入模型参数,学习率
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01, weight_decay = 5e-4)
model.train()
for epoch in range(300):
    optimizer.zero_grad()	#把梯度设置为0
    out = model(data)	#预测
    #这里data.train_mask可以筛选训练集的结点
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])	#计算loss
    loss.backward()	#计算梯度
    optimizer.step()	#更新梯度
    print("loss: {}".format(loss))

#=====================================
    #评估
model.eval()
pred = model(data).argmax(dim = 1)
#这里data.test_mask可以筛选出测试集的结点
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
#data.test_mask.sum()方法可以计算出测试集的个数
acc = int(correct) / int(data.test_mask.sum())
print("acc: ", acc)

你可能感兴趣的:(图神经网络,神经网络,深度学习)