First of all, thanks to Datawhale for the GNN course; it was excellent.
GNN/Markdown版本/6-1-数据完整存于内存的数据集类.md
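Task04 Dataset classes that keep all data in memory + node prediction and link prediction practice
1 Knowledge review
1.1 General workflow for using a dataset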
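Download the raw data files from the network;
Process the raw files to build one **Data object** for each graph sample;
Apply a processing step to each Data object, turning it into a new Data object;
Filter the Data objects;
Save the Data objects to a file;
Fetch Data objects; every time a Data object is fetched, a data transform is applied to it first, so what you get back is the transformed Data object.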
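The six steps above can be sketched with a toy class that uses only the standard library. This is a minimal illustration, not PyG's implementation: `TinyInMemoryDataset`, the hard-coded "raw" records and the `data.pkl` file name are all hypothetical, and the processing order simply follows the list above.

```python
import os
import pickle
import tempfile


class TinyInMemoryDataset:
    """Toy stand-in for an in-memory dataset class (hypothetical, not PyG)."""

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        self.root = root
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        self.processed_path = os.path.join(root, "data.pkl")
        if not os.path.exists(self.processed_path):
            self.process()
        with open(self.processed_path, "rb") as f:
            self.samples = pickle.load(f)

    def download(self):
        # Step 1: stand-in for fetching raw files; the records are hard-coded here.
        return [[1, 2, 3], [], [4, 5]]

    def process(self):
        raw = self.download()
        # Step 2: build one sample object per raw record.
        samples = [{"x": record} for record in raw]
        # Step 3: one-off processing applied before saving.
        if self.pre_transform is not None:
            samples = [self.pre_transform(s) for s in samples]
        # Step 4: drop the samples that the filter rejects.
        if self.pre_filter is not None:
            samples = [s for s in samples if self.pre_filter(s)]
        # Step 5: save the processed samples to a file.
        os.makedirs(self.root, exist_ok=True)
        with open(self.processed_path, "wb") as f:
            pickle.dump(samples, f)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Step 6: the transform runs on every access; the stored sample is untouched.
        sample = dict(self.samples[idx])
        return sample if self.transform is None else self.transform(sample)


root = tempfile.mkdtemp()
ds = TinyInMemoryDataset(
    root,
    transform=lambda s: {**s, "x": [v * 10 for v in s["x"]]},
    pre_transform=lambda s: {"x": s["x"], "n": len(s["x"])},
    pre_filter=lambda s: s["n"] > 0,
)
print(len(ds), ds[0])
```

The empty record is removed once at processing time, while the multiplication by 10 is recomputed on each access and never written to disk.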
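1.2 Link prediction task
Idea: generate negative samples so that the numbers of positive and negative samples are balanced.
Use the train_test_split_edges function to split the positive edges into training, validation and test sets and to sample negative edges for the validation and test sets; negative edges for training are drawn in every epoch with negative_sampling (see 2.3).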
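To make the balancing idea concrete, here is a small standard-library sketch of negative sampling on an undirected graph. It is a simplified stand-in rather than the PyG routine: `sample_negative_edges` and its arguments are hypothetical names, and it uses plain rejection sampling.

```python
import random


def sample_negative_edges(num_nodes, pos_edges, num_samples, seed=0):
    """Draw distinct node pairs that are not connected by a positive edge."""
    rng = random.Random(seed)
    existing = {frozenset(edge) for edge in pos_edges}
    max_possible = num_nodes * (num_nodes - 1) // 2 - len(existing)
    if num_samples > max_possible:
        raise ValueError("not enough unconnected node pairs to sample from")
    negatives = set()
    while len(negatives) < num_samples:
        u, v = rng.sample(range(num_nodes), 2)  # two distinct nodes, so no self-loops
        pair = frozenset((u, v))
        if pair not in existing:
            negatives.add(pair)
    return sorted(tuple(sorted(pair)) for pair in negatives)


pos_edges = [(0, 1), (1, 2), (2, 3)]
neg_edges = sample_negative_edges(num_nodes=5, pos_edges=pos_edges,
                                  num_samples=len(pos_edges))
# Label 1 for real edges and 0 for sampled non-edges: the two classes are balanced.
labels = [1] * len(pos_edges) + [0] * len(neg_edges)
print(neg_edges, labels)
```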
2 Hands-on practice
2.1 Building the PlanetoidPubMed dataset class (my own run was trained on the Cora dataset)
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='./tmp/cora', name='Cora')
print('Number of classes:', dataset.num_classes)
print('Number of nodes:', dataset[0].num_nodes)
print('Number of edges:', dataset[0].num_edges)
print('Node feature dimension:', dataset[0].num_features)
import os.path as osp

import torch
from torch_geometric.data import InMemoryDataset, download_url
from torch_geometric.io import read_planetoid_data


class PlanetoidPubMed(InMemoryDataset):
    r"""Nodes represent articles and edges represent citation links.
    The training, validation and test splits are given by binary masks.

    Args:
        root (string): Path of the folder where the dataset is stored.
        transform (callable, optional): Data transform, called every time
            a sample is accessed.
        pre_transform (callable, optional): Data transform, called before
            the data is saved to disk.
    """
    # url = 'https://github.com/kimiyoung/planetoid/raw/master/data'
    url = 'https://gitee.com/jiajiewu/planetoid/raw/master/data'

    def __init__(self, root, transform=None, pre_transform=None):
        super(PlanetoidPubMed, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, 'processed')

    @property
    def raw_file_names(self):
        names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
        return ['ind.pubmed.{}'.format(name) for name in names]

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        for name in self.raw_file_names:
            download_url('{}/{}'.format(self.url, name), self.raw_dir)

    def process(self):
        data = read_planetoid_data(self.raw_dir, 'pubmed')
        data = data if self.pre_transform is None else self.pre_transform(data)
        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self):
        # This class defines no `name` attribute, so use the class name instead.
        return '{}()'.format(self.__class__.__name__)
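Execution flow:
Check whether the raw data files have already been downloaded;
Check whether the data has already been processed: check the data transform method (pre_transform), check the sample filtering method (pre_filter), and check whether the processed files already exist.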
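The "skip the work if the files are already there" part of this flow can be sketched with the standard library alone. This only illustrates the idea and is not the library's code: `files_exist`, `prepare`, `toy_download`, `toy_process` and the file names are hypothetical, and the check of the saved transform and filter settings is left out.

```python
import os
import tempfile


def files_exist(paths):
    """True only if there is at least one path and all of them exist."""
    return len(paths) > 0 and all(os.path.exists(p) for p in paths)


def prepare(root, raw_names, processed_names, download, process):
    """Run download/process only when their output files are missing."""
    raw_dir = os.path.join(root, "raw")
    processed_dir = os.path.join(root, "processed")
    steps_run = []
    if not files_exist([os.path.join(raw_dir, n) for n in raw_names]):
        os.makedirs(raw_dir, exist_ok=True)
        download(raw_dir)
        steps_run.append("download")
    if not files_exist([os.path.join(processed_dir, n) for n in processed_names]):
        os.makedirs(processed_dir, exist_ok=True)
        process(raw_dir, processed_dir)
        steps_run.append("process")
    return steps_run


def toy_download(raw_dir):
    # Stand-in for a real download: write a tiny raw file.
    with open(os.path.join(raw_dir, "raw.txt"), "w") as f:
        f.write("1 2 3")


def toy_process(raw_dir, processed_dir):
    # Stand-in for real processing: store the sum of the raw values.
    with open(os.path.join(raw_dir, "raw.txt")) as f:
        values = [int(v) for v in f.read().split()]
    with open(os.path.join(processed_dir, "data.txt"), "w") as f:
        f.write(str(sum(values)))


root = tempfile.mkdtemp()
first = prepare(root, ["raw.txt"], ["data.txt"], toy_download, toy_process)
second = prepare(root, ["raw.txt"], ["data.txt"], toy_download, toy_process)
print(first, second)
```

The second call does nothing because both the raw and the processed files already exist, which is why constructing the dataset a second time is fast.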
dataset = PlanetoidPubMed('dataset/PlanetoidPubMed')
print('Number of classes:', dataset.num_classes)
print('Number of nodes:', dataset[0].num_nodes)
print('Number of edges:', dataset[0].num_edges)
print('Node feature dimension:', dataset[0].num_features)

Output:
Number of classes: 3
Number of nodes: 19717
Number of edges: 88648
Node feature dimension: 500
2.2 Node prediction with a GAT graph neural network
from torch_geometric.nn import GATConv, Sequential
from torch.nn import Linear, ReLU
import torch.nn.functional as F


class GAT(torch.nn.Module):
    def __init__(self, num_features, hidden_channels_list, num_classes):
        super(GAT, self).__init__()
        torch.manual_seed(12345)
        hns = [num_features] + hidden_channels_list
        conv_list = []
        for idx in range(len(hidden_channels_list)):
            conv_list.append((GATConv(hns[idx], hns[idx+1]), 'x, edge_index -> x'))
            conv_list.append(ReLU(inplace=True),)

        self.convseq = Sequential('x, edge_index', conv_list)
        self.linear = Linear(hidden_channels_list[-1], num_classes)

    def forward(self, x, edge_index):
        x = self.convseq(x, edge_index)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.linear(x)
        return x
def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x, data.edge_index)  # Perform a single forward pass.
    # Compute the loss solely based on the training nodes.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss


def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.
    return test_acc
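The masked accuracy computed in test() can be checked by hand on a tiny example. The sketch below uses plain Python lists instead of tensors; `masked_accuracy` and the numbers are made up for illustration.

```python
def masked_accuracy(logits, labels, mask):
    """Accuracy over the nodes whose mask entry is True."""
    correct = 0
    total = 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue
        pred = max(range(len(row)), key=row.__getitem__)  # index of the largest score
        correct += int(pred == label)
        total += 1
    if total == 0:
        raise ValueError("mask selects no nodes")
    return correct / total


logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.9, 0.8, 0.1], [0.1, 0.2, 3.0]]
labels = [0, 1, 1, 2]
test_mask = [False, True, True, True]
acc = masked_accuracy(logits, labels, test_mask)
print(f"{acc:.4f}")
```

Node 0 is predicted correctly but lies outside the mask, so it does not count; of the three masked nodes, two are correct.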
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
%matplotlib inline


def visualize(h, color):
    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
    plt.figure(figsize=(10, 10))
    plt.xticks([])
    plt.yticks([])
    plt.scatter(z[:, 0], z[:, 1], s=70, c=color.cpu(), cmap="Set2")
    plt.show()
from torch_geometric.transforms import NormalizeFeatures

dataset = PlanetoidPubMed(root='dataset/PlanetoidPubMed/', transform=NormalizeFeatures())
print('dataset.num_features:', dataset.num_features)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = dataset[0].to(device)

model = GAT(num_features=dataset.num_features, hidden_channels_list=[200, 100],
            num_classes=dataset.num_classes).to(device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(1, 201):
    loss = train()
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

test_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')

model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)
Output with PlanetoidPubMed (500 features, 3 classes):
dataset.num_features: 500
GAT(
  (convseq): Sequential(
    (0): GATConv(500, 200, heads=1)
    (1): ReLU(inplace=True)
    (2): GATConv(200, 100, heads=1)
    (3): ReLU(inplace=True)
  )
  (linear): Linear(in_features=100, out_features=3, bias=True)
)
Output when the same script is run on Cora (1433 features, 7 classes):
dataset.num_features: 1433
GAT(
  (convseq): Sequential(
    (0): GATConv(1433, 200, heads=1)
    (1): ReLU(inplace=True)
    (2): GATConv(200, 100, heads=1)
    (3): ReLU(inplace=True)
  )
  (linear): Linear(in_features=100, out_features=7, bias=True)
)
Epoch: 010, Loss: 1.7378
Epoch: 020, Loss: 0.7310
Epoch: 030, Loss: 0.2087
Epoch: 040, Loss: 0.0610
Epoch: 050, Loss: 0.0477
Epoch: 060, Loss: 0.0368
Epoch: 070, Loss: 0.0360
Epoch: 080, Loss: 0.0354
Epoch: 090, Loss: 0.0310
Epoch: 100, Loss: 0.0279
Epoch: 110, Loss: 0.0263
Epoch: 120, Loss: 0.0281
Epoch: 130, Loss: 0.0349
Epoch: 140, Loss: 0.0246
Epoch: 150, Loss: 0.0298
Epoch: 160, Loss: 0.0218
Epoch: 170, Loss: 0.0328
Epoch: 180, Loss: 0.0199
Epoch: 190, Loss: 0.0223
Epoch: 200, Loss: 0.0330
Test Accuracy: 0.7510
2.3 Link prediction with a two-layer GCNConv network
import os.path as osp

import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import train_test_split_edges
import torch_geometric.transforms as T

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = 'Cora'
path = osp.join('dataset', dataset)
# Load the Cora dataset
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]
ground_truth_edge_index = data.edge_index.to(device)
# Node masks and labels are not used for link prediction
data.train_mask = data.val_mask = data.test_mask = data.y = None
# Split the edges into training, validation and test sets
data = train_test_split_edges(data)
data = data.to(device)
from torch_geometric.nn import GCNConv


# Build the neural network
class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        return self.conv2(x, edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        # Score every candidate edge by the inner product of its two node embeddings
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        # Keep every node pair whose inner product is positive
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()
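What decode and decode_all compute can be traced on three hand-written embeddings. The sketch below re-implements the two functions on plain Python lists for illustration only; the embedding values are made up.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def decode(z, edges):
    """Inner-product score (a logit) for each candidate edge (u, v)."""
    return [dot(z[u], z[v]) for u, v in edges]


def decode_all(z):
    """All ordered node pairs whose inner product is positive."""
    n = len(z)
    return [(u, v) for u in range(n) for v in range(n) if dot(z[u], z[v]) > 0]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


z = [[1.0, 0.0], [1.0, 1.0], [-1.0, 0.0]]  # made-up node embeddings
scores = decode(z, [(0, 1), (0, 2)])
probs = [sigmoid(s) for s in scores]
predicted = decode_all(z)
print(scores, probs, predicted)
```

Nodes 0 and 1 point in similar directions and get a positive score, while nodes 0 and 2 point in opposite directions and get a negative one. Note that this thresholding also returns pairs such as (0, 0), because a non-zero embedding always has a positive inner product with itself.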
from torch_geometric.utils import negative_sampling
import torch.nn.functional as F


# Build the edge labels in {0, 1}
def get_link_labels(pos_edge_index, neg_edge_index):
    num_links = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(num_links, dtype=torch.float)
    link_labels[:pos_edge_index.size(1)] = 1.
    return link_labels


def train(data, model, optimizer):
    model.train()
    # Sample negative edges so that their number matches the positive edges
    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.train_pos_edge_index.size(1))
    optimizer.zero_grad()
    z = model.encode(data.x, data.train_pos_edge_index)
    link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index)
    link_labels = get_link_labels(data.train_pos_edge_index, neg_edge_index).to(data.x.device)
    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()
    return loss
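For reference, the loss used in train() can be written out by hand. The following is a plain-Python sketch of binary cross-entropy on logits with mean reduction, using the usual numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)); `bce_with_logits` and `link_labels` here are illustrative helpers, not the library functions.

```python
import math


def link_labels(num_pos, num_neg):
    """1.0 for each positive edge followed by 0.0 for each negative edge."""
    return [1.0] * num_pos + [0.0] * num_neg


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy computed directly from logits."""
    total = 0.0
    for x, y in zip(logits, labels):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


labels = link_labels(num_pos=1, num_neg=1)
uninformed = bce_with_logits([0.0, 0.0], labels)     # the model cannot tell the edges apart
confident = bce_with_logits([100.0, -100.0], labels)  # the model is confidently right
print(uninformed, confident)
```

With all logits at zero every edge gets probability 0.5, so the loss equals ln 2; confident and correct logits drive it toward zero without overflowing.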
from sklearn.metrics import roc_auc_score


@torch.no_grad()
def test(data, model):
    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)
    results = []
    for prefix in ['val', 'test']:
        pos_edge_index = data[f'{prefix}_pos_edge_index']
        neg_edge_index = data[f'{prefix}_neg_edge_index']
        link_logits = model.decode(z, pos_edge_index, neg_edge_index)
        # Turn the logits into probabilities of the positive class
        link_probs = link_logits.sigmoid()
        link_labels = get_link_labels(pos_edge_index, neg_edge_index)
        results.append(roc_auc_score(link_labels.cpu(), link_probs.cpu()))
    return results
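The ROC AUC reported by test() can be read as the probability that a randomly chosen positive edge is scored higher than a randomly chosen negative edge, with ties counting as one half. A small pairwise sketch in plain Python makes this explicit; `roc_auc` is a hypothetical helper written for illustration and is quadratic in the number of samples.

```python
def roc_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


labels = [1, 1, 0, 0]
scores = [0.9, 0.4, 0.5, 0.1]
auc = roc_auc(labels, scores)
print(auc)
```

Of the four positive/negative pairs only (0.4, 0.5) is ranked the wrong way round, so three out of four pairs are correct.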
model = Net(dataset.num_features, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

best_val_auc = test_auc = 0
for epoch in range(1, 101):
    loss = train(data, model, optimizer)
    val_auc, tmp_test_auc = test(data, model)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        test_auc = tmp_test_auc
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
              f'Test: {test_auc:.4f}')

z = model.encode(data.x, data.train_pos_edge_index)
final_edge_index = model.decode_all(z)
print('ground truth edge shape:', ground_truth_edge_index.shape)
print('final edge shape:', final_edge_index.shape)
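The loop above reports the test AUC from the epoch with the best validation AUC, not the best test AUC it ever saw. The selection rule can be isolated in a few lines of plain Python; `select_by_validation` and the AUC values are made up for illustration.

```python
def select_by_validation(history):
    """Return (epoch, val, test) for the first epoch with the highest validation score."""
    best_val = test_at_best = 0.0
    best_epoch = None
    for epoch, (val, test) in enumerate(history, start=1):
        if val > best_val:  # strict comparison: a later tie does not replace the earlier epoch
            best_val, test_at_best, best_epoch = val, test, epoch
    return best_epoch, best_val, test_at_best


# Made-up (validation AUC, test AUC) pairs, one per epoch.
history = [(0.70, 0.68), (0.80, 0.75), (0.78, 0.79), (0.80, 0.77)]
result = select_by_validation(history)
print(result)
```

Epoch 3 has the highest test AUC, but it is not selected: choosing by the test score would leak test information into model selection.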