使用 PyTorch Geometric 和 Heterogeneous Graph Transformer 实现异构图上的节点分类
在二部图上应用GTN算法(使用torch_geometric的库HGTConv);
导入所需的 PyTorch 和 PyTorch Geometric 库。
定义 x1 和 x2 两种不同类型节点的特征,分别有 1000 个和 500 个节点,每个节点有两维特征。
随机生成两种边 e1 和 e2 的索引(edge index)和权重(edge weight),其中 e1 从 n1 到 n2,e2 从 n2 到 n1。
定义异构图的元数据字典 meta_dict,其中 ‘n1’ 和 ‘n2’ 分别表示两种节点类型,而 (‘n1’, ‘e1’, ‘n2’) 表示从类型 ‘n1’ 的节点到类型 ‘n2’ 的节点有一条边,这条边的索引和权重分别为 edge_index_e1 和 edge_weight_e1。
利用元数据字典 meta_dict 创建异构图数据对象 data,并将节点特征和边索引添加到该对象中。
定义异构元数据列表 meta_list,其中包含所有节点类型和边类型的名称信息。
定义 HGTConv 层,并指定输入通道数、输出通道数、异构元数据列表以及头数等超参数。
将节点特征和边索引转换为字典形式,并利用 HGTConv
应用 HGTConv 到输入数据,得到输出结果 output_dict,其中包含了处理后的节点特征。最后打印输出 n1 和 n2 节点的输出形状。
以下代码可以直接运行
import torch
from torch_geometric.data import Data, HeteroData
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import HGTConv
# 定义节点特征
x1 = torch.randn(1000, 2)
x2 = torch.randn(500, 2)
# 定义边索引(edge index)以及边权重(edge weight)
edge_index_e1 = torch.cat((torch.randint(0, 1000, size=(1, 4000)),torch.randint(0, 500, size=(1, 4000))),dim=0)
edge_weight_e1 = torch.rand(4000)
edge_index_e2=torch.flip(edge_index_e1, (0,))
# 定义元数据字典,描述异构图的结构
meta_dict = {
'n1': {'num_nodes': x1.shape[0], 'num_features': x1.shape[1]},
'n2': {'num_nodes': x2.shape[0], 'num_features': x2.shape[1]},
('n1', 'e1', 'n2'): {'edge_index': edge_index_e1, 'edge_weight': edge_weight_e1},
}
# 创建异构图数据对象
data = HeteroData(meta_dict)
# 将节点特征和边索引添加到异构图对象中
data['n1'].x = x1
data['n2'].x = x2
data[('n1', 'e1', 'n2')].edge_index = edge_index_e1
data[('n2', 'e1', 'n1')].edge_index = edge_index_e2
# 定义异构元数据列表
meta_list= (['n1', 'n2'], [('n1', 'e1', 'n2'), ('n2', 'e1', 'n1')])
# 定义 HGTConv 层
in_channels = {
'n1': x1.shape[1],
'n2': x2.shape[1],
}
out_channels = 16
heads = 4
conv = HGTConv(in_channels=in_channels, out_channels=out_channels, metadata=meta_list,heads=heads)
# 将输入数据转换为字典形式
x_dict = {ntype: data[ntype].x for ntype in data.node_types}
edge_index_dict = {}
for etype in data.edge_types:
edge_index_dict[etype] = data[etype].edge_index
# 应用 HGTConv 到输入数据
output_dict = conv(x_dict, edge_index_dict)
print(output_dict['n1'].shape)
print(output_dict['n2'].shape)
之后如果是节点分类则:
output_dict的n1,n2特征编码分别接全连接层对应y1,y2
之后如果是链路预测则:
output_dict的n1,n2特征编码按照链路进行合并,进而预测
data = HeteroData(meta_dict) 创建异构图对象
edge_index_e2=torch.flip(edge_index_e1, (0,)) 创建逆向的边,由于是二部图无向图所以需要