For the principles behind HAN, see: WWW 2019 | HAN: Heterogeneous Graph Attention Network.
Load the data:
import os
from torch_geometric.datasets import DBLP

# use os.path.join for an OS-independent path to ../data/DBLP
path = os.path.join(os.path.dirname(os.getcwd()), 'data', 'DBLP')
dataset = DBLP(path)
graph = dataset[0]
print(graph)
The output is as follows:
HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={ num_nodes=20 },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)
We can see that the DBLP dataset has four node types: author, paper, term, and conference. It contains 14,328 papers, 4,057 authors, 20 conferences, and 7,723 terms. Authors fall into four research areas: database, data mining, machine learning, and information retrieval.
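The node types and their counts can also be read directly off the HeteroData object:

for node_type in graph.node_types:
    print(node_type, graph[node_type].num_nodes)
# author 4057, paper 14328, term 7723, conference 20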
Task: classify the author nodes into one of 4 classes.
Since the conference nodes have no features, we need to assign them features in advance:
graph['conference'].x = torch.ones((graph['conference'].num_nodes, 1))
All conference node features are initialized to [1].
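A one-line check confirms the new feature matrix:

print(graph['conference'].x.shape)  # torch.Size([20, 1])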
Extract some useful data:
num_classes = torch.max(graph['author'].y).item() + 1
train_mask, val_mask, test_mask = graph['author'].train_mask, graph['author'].val_mask, graph['author'].test_mask
y = graph['author'].y
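As a quick sanity check (the exact split sizes depend on the dataset version, so the mask sums are not guaranteed):

print(num_classes)  # 4
print(int(train_mask.sum()), int(val_mask.sum()), int(test_mask.sum()))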
First, import the layer:
from torch_geometric.nn import HANConv
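For reference, the constructor signature is roughly the following (as of recent PyG versions; check the documentation of your installed version):

HANConv(in_channels, out_channels, metadata, heads=1, negative_slope=0.2, dropout=0.0)

Here in_channels can be a single int, a dict mapping node types to their feature dimensions, or -1 for lazy initialization.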
The model is then built as follows:
import torch.nn as nn
import torch.nn.functional as F

class HAN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(HAN, self).__init__()
        # HANConv splits its output dimension across attention heads,
        # so heads must divide the output size (64 / 8 and 4 / 4 here)
        self.conv1 = HANConv(in_channels, hidden_channels, graph.metadata(), heads=8)
        self.conv2 = HANConv(hidden_channels, out_channels, graph.metadata(), heads=4)

    def forward(self, data):
        x_dict, edge_index_dict = data.x_dict, data.edge_index_dict
        x = self.conv1(x_dict, edge_index_dict)
        x = self.conv2(x, edge_index_dict)
        x = F.softmax(x['author'], dim=1)
        return x
Instantiate and print the model:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
graph = graph.to(device)
model = HAN(-1, 64, num_classes).to(device)
print(model)
HAN(
  (conv1): HANConv(64, heads=8)
  (conv2): HANConv(4, heads=4)
)
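Note that in_channels=-1 enables PyG's lazy initialization: the input dimension of each node type (334 for author, 4231 for paper, 50 for term, 1 for conference) is inferred on the first forward pass. A minimal shape check, assuming graph and model are on the same device:

with torch.no_grad():
    out = model(graph)  # the first call materializes the lazily-initialized weights
print(out.shape)  # torch.Size([4057, 4])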
Checking the official documentation for HANConv's input and output requirements, we can see that HANConv expects a node-feature dictionary x_dict and an edge-index dictionary edge_index_dict as input.
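We can inspect these dictionaries directly on the HeteroData object:

print(list(graph.x_dict.keys()))
# ['author', 'paper', 'term', 'conference']
print(list(graph.edge_index_dict.keys()))
# [('author', 'to', 'paper'), ('paper', 'to', 'author'), ...]
print(graph.metadata())  # (node_types, edge_types), the tuple passed to HANConv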
Therefore we have:
x_dict, edge_index_dict = data.x_dict, data.edge_index_dict
x = self.conv1(x_dict, edge_index_dict)
At this point, let's print x['author'] and its size:
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0969, 0.0601, 0.0000,  ..., 0.0000, 0.0000, 0.0251],
        [0.0000, 0.0000, 0.0000,  ..., 0.1288, 0.0000, 0.0602],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0096, 0.0000, 0.0240],
        [0.0000, 0.0000, 0.0000,  ..., 0.0096, 0.0000, 0.0240],
        [0.0801, 0.0558, 0.0837,  ..., 0.0277, 0.0347, 0.0000]],
       device='cuda:0', grad_fn=<SumBackward1>)
torch.Size([4057, 64])
The resulting x['author'] has 4,057 rows; each row is the updated state vector of one author node after the first convolution layer. The 64 hidden dimensions come from the 8 attention heads, each producing an 8-dimensional slice that is concatenated back together (which is why hidden_channels must be divisible by heads).
By the same reasoning, since:
x = self.conv2(x, edge_index_dict)
the size of x['author'] obtained after the second convolution layer should be:
torch.Size([4057, 4])
That is, each author node now has a 4-dimensional state vector.
Since we are doing 4-class classification, we finally apply a softmax:
x = F.softmax(x, dim=1)
dim=1 means the softmax is taken over each row, so every row sums to 1 and can be read as the probabilities of that node belonging to each class. (One caveat: CrossEntropyLoss, used below, already applies log-softmax internally, so the model would also train on raw logits; the explicit softmax is kept here so the outputs are directly interpretable as probabilities.) Print x at this point:
tensor([[0.2591, 0.2539, 0.2435, 0.2435],
        [0.3747, 0.2067, 0.2029, 0.2157],
        [0.2986, 0.2338, 0.2338, 0.2338],
        ...,
        [0.2740, 0.2453, 0.2403, 0.2403],
        [0.2740, 0.2453, 0.2403, 0.2403],
        [0.3414, 0.2195, 0.2195, 0.2195]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)
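To make dim=1 concrete, here is a tiny standalone example (the input values are made up):

import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 2.0, 0.5, 0.5]])
probs = F.softmax(logits, dim=1)
print(probs)             # tensor([[0.2028, 0.5512, 0.1230, 0.1230]])
print(probs.sum(dim=1))  # tensor([1.])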
During training, we first compute the output with a forward pass:
f = model(graph)
f holds the 4 class probabilities for each node, but in training we only need the loss over the training set, so the loss function is written as:
loss = loss_function(f[train_mask], y[train_mask])
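Boolean masks select rows, so f[train_mask] keeps only the predictions of training-set authors. A tiny made-up example of this indexing:

import torch
scores = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])
mask = torch.tensor([True, False, True])
print(scores[mask])  # tensor([[0.9000, 0.1000], [0.5000, 0.5000]])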
Then we compute the gradients and update the parameters by backpropagation.
During training, we return the test accuracy from the epoch that performed best on the validation set:
def train():
    model = HAN(-1, 64, num_classes).to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    min_epochs = 5
    best_val_acc = 0
    final_best_acc = 0
    for epoch in range(100):
        model.train()  # test() switches to eval mode, so re-enable train mode each epoch
        f = model(graph)
        loss = loss_function(f[train_mask], y[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # validation
        val_acc, val_loss = test(model, val_mask)
        test_acc, test_loss = test(model, test_mask)
        if epoch + 1 > min_epochs and val_acc > best_val_acc:
            best_val_acc = val_acc
            final_best_acc = test_acc
        print('Epoch {:3d} train_loss {:.5f} val_acc {:.3f} test_acc {:.3f}'
              .format(epoch, loss.item(), val_acc, test_acc))
    return final_best_acc
def test(model, mask):
    model.eval()
    with torch.no_grad():
        out = model(graph)
        loss_function = torch.nn.CrossEntropyLoss().to(device)
        loss = loss_function(out[mask], y[mask])
    _, pred = out.max(dim=1)
    correct = int(pred[mask].eq(y[mask]).sum().item())
    acc = correct / int(mask.sum())  # divide by the passed-in mask, not test_mask
    return acc, loss.item()
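A minimal driver to run everything end to end (the label in the print matches the output below):

if __name__ == '__main__':
    final_best_acc = train()
    print('HAN Accuracy:', final_best_acc)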
Training for 100 epochs on the DBLP dataset, the classification accuracy is 78.54%:
HAN Accuracy: 0.7853853239177156
Code repository: GNNs-for-Node-Classification. Original content takes effort to produce; if you download the code, please give it a follow and a star. Thank you!