Note:
Click here to download the full example code
Authors: Hao Zhang, Mufei Li
, Minjie Wang Zheng Zhang
在本教程中,您将学习图注意力网络(GAT)以及如何在PyTorch中实现它。您还可以学习可视化并了解注意力机制所学到的知识。
图卷积网络(GCN)中描述的研究表明,结合局部图结构和节点级特征可以在节点分类任务上产生良好的性能。但是,GCN聚合的方式取决于结构,这可能会损害其通用性。
一种解决方法是按研究论文GraphSAGE中所述简单地平均所有邻居节点特征。但是,Graph Attention Network提出了另一种类型的聚合。GAN以关注方式使用具有特征依赖和无结构归一化的加权邻居特征。
GAT和GCN之间的主要区别在于如何汇总来自一跳社区的信息。
对于GCN,图卷积运算会生成邻居节点特征的归一化总和。
h i ( l + 1 ) = σ ( ∑ j ∈ N ( i ) 1 c i j W ( l ) h j ( l ) ) h_i^{(l+1)}=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\frac{1}{c_{ij}} W^{(l)}h^{(l)}_j}\right) hi(l+1)=σ⎝⎛j∈N(i)∑cij1W(l)hj(l)⎠⎞
哪里 N ( i ) \mathcal{N}(i) N(i)是其一跳邻居的集合(包括 v i v_i vi在集合中,只需向每个节点添加一个自环), c i j = ∣ N ( i ) ∣ ∣ N ( j ) ∣ c_{ij}=\sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|} cij=∣N(i)∣∣N(j)∣是基于图结构的归一化常数, σ \sigma σ是激活功能(GCN使用ReLU),并且 W ( l ) W^{(l)} W(l)是用于节点特征转换的共享权重矩阵。GraphSAGE中提出的另一种模型 采用相同的更新规则,只是它们设置了 c i j = ∣ N ( i ) ∣ c_{ij}=|\mathcal{N}(i)| cij=∣N(i)∣。
GAT引入了注意力机制,以替代静态归一化卷积运算。以下是计算节点嵌入的方程式 h i ( l + 1 ) h_i^{(l+1)} hi(l+1)层数 l + 1 l+1 l+1从图层的嵌入 l l l。
z i ( l ) = W ( l ) h i ( l ) , ( 1 ) e i j ( l ) = LeakyReLU ( a ⃗ ( l ) T ( z i ( l ) ∣ ∣ z j ( l ) ) ) , ( 2 ) α i j ( l ) = exp ( e i j ( l ) ) ∑ k ∈ N ( i ) exp ( e i k ( l ) ) , ( 3 ) h i ( l + 1 ) = σ ( ∑ j ∈ N ( i ) α i j ( l ) z j ( l ) ) , ( 4 ) z_i^{(l)}=W^{(l)}h_i^{(l)},(1)\\ e_{ij}^{(l)}=\text{LeakyReLU}(\vec a^{(l)^T}(z_i^{(l)}||z_j^{(l)})),(2)\\ \alpha_{ij}^{(l)}=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},(3)\\ h_i^{(l+1)}=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),(4) zi(l)=W(l)hi(l),(1)eij(l)=LeakyReLU(a(l)T(zi(l)∣∣zj(l))),(2)αij(l)=∑k∈N(i)exp(eik(l))exp(eij(l)),(3)hi(l+1)=σ⎝⎛j∈N(i)∑αij(l)zj(l)⎠⎞,(4)
说明:
首先,您可以大致了解如何GATLayer在DGL中实现模块。在本节中,上面的四个方程式一次分解为一个。
import torch
import torch.nn as nn
import torch.nn.functional as F
class GATLayer(nn.Module):
def __init__(self, g, in_dim, out_dim):
super(GATLayer, self).__init__()
self.g = g
# equation (1)
self.fc = nn.Linear(in_dim, out_dim, bias=False)
# equation (2)
self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
def edge_attention(self, edges):
# edge UDF for equation (2)
z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
a = self.attn_fc(z2)
return {'e': F.leaky_relu(a)}
def message_func(self, edges):
# message UDF for equation (3) & (4)
return {'z': edges.src['z'], 'e': edges.data['e']}
def reduce_func(self, nodes):
# reduce UDF for equation (3) & (4)
# equation (3)
alpha = F.softmax(nodes.mailbox['e'], dim=1)
# equation (4)
h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
return {'h': h}
def forward(self, h):
# equation (1)
z = self.fc(h)
self.g.ndata['z'] = z
# equation (2)
self.g.apply_edges(self.edge_attention)
# equation (3) & (4)
self.g.update_all(self.message_func, self.reduce_func)
return self.g.ndata.pop('h')
z i ( l ) = W ( l ) h i ( l ) , ( 1 ) z_i^{(l)}=W^{(l)}h_i^{(l)},(1) zi(l)=W(l)hi(l),(1)
第一个显示线性变换。这很常见,可以使用在Pytorch中轻松实现torch.nn.Linear。
e i j ( l ) = LeakyReLU ( a ⃗ ( l ) T ( z i ( l ) ∣ z j ( l ) ) ) , ( 2 ) e_{ij}^{(l)}=\text{LeakyReLU}(\vec a^{(l)^T}(z_i^{(l)}|z_j^{(l)})),(2) eij(l)=LeakyReLU(a(l)T(zi(l)∣zj(l))),(2)
非标准化注意力得分 e i j e_{ij} eij使用相邻节点的嵌入来计算 i i i 和 j j j。这表明注意力得分可以看作是边缘数据,可以由apply_edgesAPI 计算得出 。的参数apply_edges是Edge UDF,其定义如下:
def edge_attention(self, edges):
# edge UDF for equation (2)
z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
a = self.attn_fc(z2)
return {'e' : F.leaky_relu(a)}
在这里,点积与可学习的权重向量 a ( l ) ⃗ \vec{a^{(l)}} a(l)使用PyTorch的线性变换再次实现attn_fc。需要注意的是apply_edges意志批次都在同一个张量的边缘数据,所以 cat,attn_fc这里是平行的所有边应用。
α i j ( l ) = exp ( e i j ( l ) ) ∑ k ∈ N ( i ) exp ( e i k ( l ) ) , ( 3 ) h i ( l + 1 ) = σ ( ∑ j ∈ N ( i ) α i j ( l ) z j ( l ) ) , ( 4 ) \alpha_{ij}^{(l)}=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},(3)\\ h_i^{(l+1)}=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),(4) αij(l)=∑k∈N(i)exp(eik(l))exp(eij(l)),(3)hi(l+1)=σ⎝⎛j∈N(i)∑αij(l)zj(l)⎠⎞,(4)
与GCN类似,update_allAPI用于触发所有节点上的消息传递。消息函数发出两个张量:z 源节点的转换嵌入和e每个边缘上的非标准化注意力得分。然后reduce函数执行两项任务:
这两个任务都首先从邮箱中获取数据,然后在dim=1批量处理邮件的第二个维度()上对其进行操作。
def reduce_func(self, nodes):
# reduce UDF for equation (3) & (4)
# equation (3)
alpha = F.softmax(nodes.mailbox['e'], dim=1)
# equation (4)
h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
return {'h' : h}
类似于ConvNet中的多个渠道,GAT引入了多头关注,以丰富模型功能并稳定学习过程。每个关注头都有自己的参数,它们的输出可以通过两种方式合并:
concatenation : h i ( l + 1 ) = ∣ ∣ k = 1 K σ ( ∑ j ∈ N ( i ) α i j k W k h j ( l ) ) \text{concatenation}: h^{(l+1)}_{i} =||_{k=1}^{K}\sigma\left(\sum_{j\in \mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right) concatenation:hi(l+1)=∣∣k=1Kσ⎝⎛j∈N(i)∑αijkWkhj(l)⎠⎞
要么
average : h i ( l + 1 ) = σ ( 1 K ∑ k = 1 K ∑ j ∈ N ( i ) α i j k W k h j ( l ) ) \text{average}: h_{i}^{(l+1)}=\sigma\left(\frac{1}{K}\sum_{k=1}^{K}\sum_{j\in\mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right) average:hi(l+1)=σ⎝⎛K1k=1∑Kj∈N(i)∑αijkWkhj(l)⎠⎞
哪里 K K K是头数。您可以将串联用于中间层,将平均值用于最后一层。
将上面定义的单头GATLayer用作以下内容的构建基块MultiHeadGATLayer:
class MultiHeadGATLayer(nn.Module):
def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
super(MultiHeadGATLayer, self).__init__()
self.heads = nn.ModuleList()
for i in range(num_heads):
self.heads.append(GATLayer(g, in_dim, out_dim))
self.merge = merge
def forward(self, h):
head_outs = [attn_head(h) for attn_head in self.heads]
if self.merge == 'cat':
# concat on the output feature dimension (dim=1)
return torch.cat(head_outs, dim=1)
else:
# merge using average
return torch.mean(torch.stack(head_outs))
现在,您可以定义一个两层的GAT模型。
class GAT(nn.Module):
def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
super(GAT, self).__init__()
self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
# Be aware that the input dimension is hidden_dim*num_heads since
# multiple head outputs are concatenated together. Also, only
# one attention head in the output layer.
self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)
def forward(self, h):
h = self.layer1(h)
h = F.elu(h)
h = self.layer2(h)
return h
然后,我们使用DGL的内置数据模块加载Cora数据集。
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx
def load_cora_data():
data = citegrh.load_cora()
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
mask = torch.BoolTensor(data.train_mask)
g = data.graph
# add self loop
g.remove_edges_from(nx.selfloop_edges(g))
g = DGLGraph(g)
g.add_edges(g.nodes(), g.nodes())
return g, features, labels, mask
训练循环与GCN教程中的完全相同。
import time
import numpy as np
g, features, labels, mask = load_cora_data()
# create the model, 2 heads, each head has hidden size 8
net = GAT(g,
in_dim=features.size()[1],
hidden_dim=8,
out_dim=7,
num_heads=2)
# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# main loop
dur = []
for epoch in range(30):
if epoch >= 3:
t0 = time.time()
logits = net(features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[mask], labels[mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
epoch, loss.item(), np.mean(dur)))
out:
/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3257: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars
ret = ret.dtype.type(ret / rcount)
Epoch 00000 | Loss 1.9462 | Time(s) nan
Epoch 00001 | Loss 1.9456 | Time(s) nan
Epoch 00002 | Loss 1.9449 | Time(s) nan
Epoch 00003 | Loss 1.9442 | Time(s) 0.2926
Epoch 00004 | Loss 1.9435 | Time(s) 0.2970
Epoch 00005 | Loss 1.9428 | Time(s) 0.2960
Epoch 00006 | Loss 1.9421 | Time(s) 0.2944
Epoch 00007 | Loss 1.9414 | Time(s) 0.2955
Epoch 00008 | Loss 1.9406 | Time(s) 0.2947
Epoch 00009 | Loss 1.9399 | Time(s) 0.2945
Epoch 00010 | Loss 1.9391 | Time(s) 0.2945
Epoch 00011 | Loss 1.9384 | Time(s) 0.2955
Epoch 00012 | Loss 1.9376 | Time(s) 0.2951
Epoch 00013 | Loss 1.9368 | Time(s) 0.2946
Epoch 00014 | Loss 1.9360 | Time(s) 0.2955
Epoch 00015 | Loss 1.9351 | Time(s) 0.2954
Epoch 00016 | Loss 1.9343 | Time(s) 0.2954
Epoch 00017 | Loss 1.9334 | Time(s) 0.2953
Epoch 00018 | Loss 1.9325 | Time(s) 0.2959
Epoch 00019 | Loss 1.9317 | Time(s) 0.2964
Epoch 00020 | Loss 1.9307 | Time(s) 0.2963
Epoch 00021 | Loss 1.9298 | Time(s) 0.2968
Epoch 00022 | Loss 1.9289 | Time(s) 0.2968
Epoch 00023 | Loss 1.9279 | Time(s) 0.2966
Epoch 00024 | Loss 1.9269 | Time(s) 0.2964
Epoch 00025 | Loss 1.9259 | Time(s) 0.2955
Epoch 00026 | Loss 1.9249 | Time(s) 0.2955
Epoch 00027 | Loss 1.9238 | Time(s) 0.2944
Epoch 00028 | Loss 1.9228 | Time(s) 0.2937
Epoch 00029 | Loss 1.9217 | Time(s) 0.2927
下表总结了GAT论文中报告并通过DGL实现获得的Cora模型性能 。
Model | Accuracy |
---|---|
GCN (paper) | 81.4±0.5 |
GCN (dgl) | 82.05±0.33 |
GAT (paper) | 83.0±0.7 |
GAT (dgl) | 83.69±0.529 |
我们的模型学到了什么样的注意力分布?
因为注意体重 a i j a_{ij} aij与边缘相关联,您可以通过为边缘着色来形象化它。在下面,您可以选择Cora的一个子图,并绘制最后一个的注意权重GATLayer。节点根据其标签进行着色,而边缘根据注意权重的大小进行着色,这可以通过右侧的色条来参考。
您可以看到该模型似乎学习了不同的注意力权重。要更全面地了解分布,请测量注意力分布的熵。对于任何节点 i i i, { α i j } j ∈ N ( i ) \{\alpha_{ij}\}_{j\in\mathcal{N}(i)} {αij}j∈N(i)通过以下公式给出的熵在其所有邻居上形成离散的概率分布
H ( α i j j ∈ N ( i ) ) = − ∑ j ∈ N ( i ) α i j log α i j H({\alpha_{ij}}_{j\in\mathcal{N}(i)})=-\sum_{j\in\mathcal{N}(i)} \alpha_{ij}\log\alpha_{ij} H(αijj∈N(i))=−j∈N(i)∑αijlogαij
熵低意味着集中度高,反之亦然。熵为0表示所有注意力都集中在一个源节点上。均匀分布具有最高的熵 log ( N ( i ) ) \log(\mathcal{N}(i)) log(N(i))。理想情况下,您希望看到模型学习较低熵的分布(即,一个或两个邻居比其他邻居重要得多)。
注意,由于节点可以具有不同的度数,因此最大熵也将不同。因此,您可以绘制整个图中所有节点的熵值的聚合直方图。以下是每个注意头学习到的注意直方图。
作为参考,以下是所有节点均具有统一注意力权重分布的直方图。
可以看到,学习到的注意力值非常类似于均匀分布 (即,所有邻居都同等重要)。这部分解释了为什么GAT在Cora上的性能接近GCN的性能(根据作者的报告结果,100次运行的平均准确度差异小于2%)。注意并不重要,因为它的区别不大。
*这是否意味着注意力机制没有用?*并不是!另一个不同的数据集表现出完全不同的模式,如下所示。
此处使用的PPI数据集包括 24对应于不同人体组织的图形。节点最多可以121 标签的种类,因此节点的标签表示为大小的二进制张量 121。任务是预测节点标签。
采用 20 训练图 2 用于验证和 2 进行测试。每个图的平均节点数为2372。每个节点都有50由位置基因集,基序基因集和免疫特征组成的特征。至关重要的是,在训练过程中完全看不见测试图,这种设置称为“归纳学习”。
比较GAT和GCN的性能 10 随机运行此任务,并在验证集上使用超参数搜索来找到最佳模型。
Model | F1 Score(micro) |
---|---|
GAT | 0.975±0.006 |
GCN | 0.509±0.025 |
Paper | 0.973±0.002 |
上表是该实验的结果,您可以使用micro F1分数来评估模型性能。
Note:
以下是F1分数的计算过程:
p r e c i s i o n = ∑ t = 1 n T P ( t ) ∑ t = 1 n T P ( t ) + F P ( t ) ) precision=\frac{\sum_{t=1}^{n}TP(t)}{\sum_{t=1}^{n}TP(t)+FP(t))} precision=∑t=1nTP(t)+FP(t))∑t=1nTP(t)
r e c a l l = ∑ t = 1 n T P ( t ) ∑ t = 1 n T P ( t ) + F N ( t ) ) recall=\frac{\sum_{t=1}^{n}TP(t)}{\sum_{t=1}^{n}TP(t)+FN(t))} recall=∑t=1nTP(t)+FN(t))∑t=1nTP(t)
F 1 m i c r o = 2 p r e c i s i o n ∗ r e c a l l p r e c i s i o n + r e c a l l F1_{micro}=2\frac{precision*recall}{precision+recall} F1micro=2precision+recallprecision∗recall
T P t TP_t TPt 表示同时具有和预计具有标签的节点数 t t t
F P t FP_t FPt 表示没有但预计具有标签的节点数 t t t
F N t FN_t FNt 代表标记为 t t t 但与其他人一样预测。
n n n 是标签数,即 121 121 121 就我们而言。
在训练过程中,BCEWithLogitsLoss用作损失功能。GAT和GCN的学习曲线如下所示;显而易见的是,与GCN相比,GAT的显着性能优势。
与以前一样,您可以通过显示节点式注意熵的直方图来对学习到的注意事项进行统计理解。以下是不同注意力层学习的注意力直方图。
在第1层中学习到的注意力:
在第2层中学习到的注意力:
在最后一层学到的注意力:
再次,与均匀分布比较:
显然,GAT确实学习了敏锐的关注权重!各层上也有清晰的图案:一层越多,注意力越集中。
与Cora数据集的GAT增益极少不同,对于PPI,与GAT论文相比,GPI与其他GNN变体之间存在显着的性能差距(至少20%),并且两者之间的注意力分布明显不同。尽管这值得进一步研究,但一个直接的结论是,GAT的优势可能更多在于处理具有更复杂邻域结构的图形的能力。
到目前为止,您已经了解了如何使用DGL来实现GAT。缺少一些遗漏的详细信息,例如退出,跳过连接和超参数调整,这些实践不涉及DGL相关概念。有关更多信息,请查看完整示例。
脚本的总运行时间:(0分钟15.065秒)
下载脚本:9_gat.py
下载脚本:9_gat.ipynb