论文:Self-Attention Graph Pooling
作者:Junhyun Lee, Inyeop Lee, Jaewoo Kang
韩国首尔高丽大学计算机科学与工程系
来源:ICML 2019
论文链接:
Arxiv: https://arxiv.org/abs/1904.08082
代码地址:https://github.com/inyeoplee77/SAGPool
本文作者提出一种新的基于self-attention机制的图池化方法SAGPool,方法充分考虑了节点的特征和图的拓扑结构。在图分类评测任务上SAGPool效果拔群。
SAGPool具有前几种方法的优点:分层池化,同时考虑节点特征和图的拓扑结构(因为利用图卷积得到self-attention分数),合理的复杂度,以及端到端表示学习。SAGPool是第一个使用self-attention进行图池化处理并实现高性能的方法。SAGPool参数量一致,不用考虑输入图的大小。
目前,图池化的方法比图卷积的方法要少,而现存的基于池化的方法存在一些问题:
文中提出了SAGPool,这是一种基于层次图池化的Self-Attention Graph方法。
为什么说文中的attention机制是一种self-attention呢?和GAT中的marsked attention有什么区别呢?
self-attention是一种Global graph attention,会将注意力分配到图中所有的节点上,直接计算图结构中任意两个节点之间的关系,一步到位地获取图结构的全局几何特征。
self−attention利用了attention机制,分三个阶段进行计算:
通过self-attention注意力机制可以计算任意两个样本的关系,一个样本可以用其他所有样本来表示,但是存在一些问题:
(1)基于空间相似假设,一个样本应与一定范围内的样本关系较密切
(2)样本较多的时候,计算量非常大。
为了解决这上述问题,GAT中使用了一种 masked attention 的方法:对于一个样本来说只利用邻域内的样本计算注意力系数和新的表示,即仅将注意力分配到节点的一阶邻居节点集上。
图数据池化方法可以分为以下三类:基于拓扑的池化、全局池化和分层池化。
基于拓扑的池化主要考虑了图的结构特征。早期的工作使用的是图的粗化算法,而不是使用神经网络。谱聚类算法利用特征分解得到粗化图。然而,由于特征分解的时间复杂度过大问题,需要一些替代方法:
(1) Weighted graph cuts without eigenvectors a multilevel approach, 2007
(2)在最近的GNN模型中Graclus被用作池化模块:
与基于拓扑的池化方法不同,全局池化方法考虑了图的属性特征。全局池化方法使用求和或神经网络对每个层中所有节点的表示进行一次性pool:
(1)Neural message passing for quantum chemistry,2017)
将GNNs视为消息传递方案,提出了一种图分类的通用框架,利用Set2Set方法可以获得整个图的表示。
(2)An end-to-end deep learning architecture for graph classification,AAAI 2018
也叫SortPool,它根据图的结构对节点的embeddings进行排序,并将排序后的embeddings传递给下一层。
全局池化方法没有学习对捕获图结构信息至关重要的层次表示。分层池化方法的主要动机是建立一个能够学习每一层中基于特征或拓扑的节点分配模型:
(1)[DIFFPOOL] Hierarchical Graph Representation Learning with Differentiable Pooling,NeurIPS 2018
具体细节,可以参考另一篇博文:[DIFFPOOL 图分类] - Hierarchical Graph Representation Learning with Differentiable Pooling NeurIPS 2018](https://blog.csdn.net/yyl424525/article/details/103307795)
(2)Graph u-net,ICML 2019
gPool实现了与DiffPool相当的性能。gPool需要 O ( ∣ V ∣ + ∣ E ∣ ) O(|V| + |E|) O(∣V∣+∣E∣)的空间复杂度,而DiffPool需要 O ( k ∣ V ∣ 2 ) O(k|V|^2) O(k∣V∣2),其中 V , E , k V,E,k V,E,k分别表示顶点数、边数和池化比率。gPool使用一个可学习的向量 p p p来计算投影分数,然后使用这些分数来从图中选择排名最高的节点保留下来。
文中是基于Graph U-Nets的,因此必须先了解,可以参考另一篇博文:Graph U-Nets
为了进一步改进图池化方法,文中提出了SAGPool,它可以使用图的特征和拓扑结构信息来产生具有合理的时间和空间复杂度的层次表示。
SAGPool的关键在于它使用GNN来提供self-attention分数。
使用注意力机制可以关注更重要的特征。self-attention,通常被称为intra-attention,关注的特征是注意力本身。SAGPool利用图卷积的方法得到self-attention分数。例如,使用Kipf的图卷积公式,则self-attention分数 Z ∈ R N × 1 Z \in \mathbb{R}^{N \times 1} Z∈RN×1根据如下计算:
Z = σ ( D ~ − 1 2 A ~ D ~ − 1 2 X Θ a t t ) (3) \tag{3} Z=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} X \Theta_{a t t}\right) Z=σ(D~−21A~D~−21XΘatt)(3)
因为利用公式里融合了A和X的图卷积得到self-attention分数,所以这种池化的结果是基于图的特征和拓扑的。
SAGPool采用了gPool中的节点选择方法,保留了输入图的一部分节点:
i d x = top-rank ( Z , ⌈ k N ⌉ ) , Z m a s k = Z i d x (4) \tag{4} \mathrm{idx}=\operatorname{top-rank}(Z,\lceil k N\rceil), \quad Z_{m a s k}=Z_{\mathrm{idx}} idx=top-rank(Z,⌈kN⌉),Zmask=Zidx(4)
池化部分就是根据idx对特征和结构进行topK的选择了:
X ′ = X i d x , : , X o u t = X ′ ⊙ Z m a s k , A o u t = A i d x , i d x (5) \tag{5} X^{\prime}=X_{\mathrm{idx}, \mathrm{:}} \quad, \quad X_{o u t}=X^{\prime} \odot Z_{mask} \quad , \quad A_{o u t}=A_{\mathrm{idx}, \mathrm{idx}} X′=Xidx,:,Xout=X′⊙Zmask,Aout=Aidx,idx(5)
SAGPool中使用图卷积的主要原因是为了获得拓扑结构和节点特征。可以使用不同的GNN代替GCN(其他的如GAT、GraphSAGE等),所以计算计算注意力分数 Z ∈ R N × 1 Z \in \mathbb{R}^{N \times 1} Z∈RN×1的公式可以泛化为:
Z = σ ( GNN ( X , A ) ) (6) \tag{6} Z=\sigma(\operatorname{GNN}(X, A)) Z=σ(GNN(X,A))(6)
计算注意力分数,不仅可以使用相邻节点,也可以使用多跳连接的节点。可以使用添加改变邻接矩阵形式扩展边,堆叠多层GNN层,使用多个注意力分数的平均值等方法来达到这个目的。
以一个连接两跳的节点为例。
(1)添加邻接矩阵的平方: SAGPool augmentation \text { SAGPool}_{\text {augmentation}} SAGPoolaugmentation
式(7)使用了两跳连接,该连接涉及边的扩展,允许两跳节点的间接聚合。添加邻接矩阵的平方相当于在两跳邻居之间创建了边:
Z = σ ( GNN ( X , A + A 2 ) ) (7) \tag{7} Z=\sigma\left(\operatorname{GNN}\left(X, A+A^{2}\right)\right) Z=σ(GNN(X,A+A2))(7)
(2)叠加两层GNN层: SAGPool serial \text { SAGPool}_{\text {serial}} SAGPoolserial
式(8)使用了两跳连接,该连接涉及GNN层的堆叠,允许两跳节点的间接聚合。在这种情况下,SAGPool层的非线性和参数数量将增加:
Z = σ ( GNN 2 ( σ ( GNN 1 ( X , A ) ) , A ) ) (8) \tag{8} Z=\sigma\left(\operatorname{GNN}_{2}\left(\sigma\left(\operatorname{GNN}_{1}(X, A)\right), A\right)\right) Z=σ(GNN2(σ(GNN1(X,A)),A))(8)
公式(7)和公式(8)可以应用到更多跳的连接上。
(3)取多重注意力分数的平均值,类似于Multi-head GAT: SAGPool parallel \text { SAGPool}_{\text {parallel}} SAGPoolparallel
M M M个GNNs平均注意力分值:
Z = 1 M ∑ m σ ( G N N m ( X , A ) ) (9) \tag{9} Z=\frac{1}{M} \sum_{m} \sigma\left(\mathrm{GNN}_{m}(X, A)\right) Z=M1m∑σ(GNNm(X,A))(9)
文中分别将公式(7),(8),(9)中的模型称为 SAGPool augmentation \text { SAGPool}_{\text {augmentation}} SAGPoolaugmentation, SAGPool serial \text { SAGPool}_{\text {serial}} SAGPoolserial和 SAGPool parallel \text { SAGPool}_{\text {parallel}} SAGPoolparallel。
看看SAGPool的源代码:
$$
from torch_geometric.nn import GCNConv
from torch_geometric.nn.pool.topk_pool import topk,filter_adj
from torch.nn import Parameter
import torch
class SAGPool(torch.nn.Module):
def init(self,in_channels,ratio=0.8,Conv=GCNConv,non_linearity=torch.tanh):
super(SAGPool,self).init()
self.in_channels = in_channels
self.ratio = ratio # 论文中的参数k
self.score_layer = Conv(in_channels,1) # 论文中的Z
self.non_linearity = non_linearity
def forward(self, x, edge_index, edge_attr=None, batch=None):
if batch is None:
batch = edge_index.new_zeros(x.size(0))
#x = x.unsqueeze(-1) if x.dim() == 1 else x
score = self.score_layer(x,edge_index).squeeze()
perm = topk(score, self.ratio, batch) # topk选择最大的几个
x = x[perm] * self.non_linearity(score[perm]).view(-1, 1) # mask
batch = batch[perm]
edge_index, edge_attr = filter_adj( # 选择子图结构特征
edge_index, edge_attr, perm, num_nodes=score.size(0))
return x, edge_index, edge_attr, batch, perm
$$
上文提到了GNN可以有很多种,如GAT、GraphSAGE等,本文还是用了Kipf的卷积:
h ( l + 1 ) = σ ( D ~ − 1 2 A ~ D ~ − 1 2 h ( l ) Θ ) (10) \tag{10} h^{(l+1)}=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} h^{(l)} \Theta\right) h(l+1)=σ(D~−21A~D~−21h(l)Θ)(10)
受JK-net架构(Representation learning on graphs with jumping knowledge networks,2018;Towards sparse hierarchical graph classifiers,2018)的启发,提出了一种readout层,该层聚合节点特征以形成固定大小的表示。readout层的输出特征如下:
s = 1 N ∑ i = 1 N x i ∥ max i = 1 N x i (11) \tag{11} s=\frac{1}{N} \sum_{i=1}^{N} x_{i} \| \max _{i=1}^{N} x_{i} s=N1i=1∑Nxi∥i=1maxNxi(11)
代码为:
f r o m t o r c h g e o m e t r i c . n n i m p o r t g l o b a l m e a n p o o l a s g a p , g l o b a l m a x p o o l a s g m p x 1 = t o r c h . c a t ( [ g m p ( x , b a t c h ) , g a p ( x , b a t c h ) ] , d i m = 1 ) from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) fromtorchgeometric.nnimportglobalmeanpoolasgap,globalmaxpoolasgmpx1=torch.cat([gmp(x,batch),gap(x,batch)],dim=1)
实现了(An end-to-end deep learning architecture for graph classification,AAAI 2018)中提出的全局池化架构。
实现了(Towards sparse hierarchical graph classifiers,2018)分层池化架构。
对于具体的模型架构,本文使用了两种Global pooling architecture和Hierarchical pooling architecture:
在图分类任务中,评估了全局池化和分层池化方法。
选取了5个图的数量较大的数据集( > 1 k > 1k >1k):
很难确定全局池化结构或层次池化结构是否完全有利于图形分类。因为全局池化结构 P O O L g P O O L_{g} POOLg( SAGPool g \text { SAGPool }_{g} SAGPool g、 SortPool g \text { SortPool }_{g} SortPool g、 Set2Set g \text { Set2Set }_{g} Set2Set g)使信息丢失最小化,因此它在节点较少的数据集(NCI1、NCI109、FRANKENSTEIN)上的性能优于分层池化结构 P O O L h P O O L_{h} POOLh( SAGPool h \text { SAGPool }_{h} SAGPool h、 gPool h \text { gPool }_{h} gPool h、 DiffPool h \text { DiffPool }_{h} DiffPool h)。
但是,分层池化 P O O L h P O O L_{h} POOLh对节点数较多的数据集(D&D和PROTEINS)更有效,因为它能有效地从大规模图中提取有用的信息。因此,使用最适合给定数据的池化结构非常重要。尽管如此,SAGPool在每种架构中通常都表现良好。
和gPool不一样, SAGPool使用一阶近似图的拉普拉斯算子 D ~ − 1 2 A ~ D ~ − 1 2 \tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} D~−21A~D~−21,这使得SAGPool考虑了图的拓扑结构。如表3所示,考虑图的拓扑结构可以提高性能。此外,图的Laplacian算子不需要重新计算,因为它在前一个图卷积层中也使用了,可以预先计算。
虽然SAGPool具有与gPool相同的参数,但它在图分类任务中表现出了更优异的性能。
结果表明,SAGPool在总体上表现良好,在D&D和PROTEINS方面表现尤为突出。在实验中,SAGPool在所有的数据集上都优于分层池化的方法。
使用稀疏矩阵操作图数据对于GNNs来说非常重要,因为邻接矩阵通常是稀疏的。
用稠密矩阵计算图卷积时,乘法 A X AX AX的计算复杂度为 O ( ∣ V ∣ 2 ) O(|V|^2) O(∣V∣2),其中 A A A为邻接矩阵, X X X为节点特征矩阵, V V V为顶点。如(Towards sparse hierarchical graph classifiers,2018)所述,密集矩阵池化会导致内存效率问题。
如果在同一操作中使用稀疏矩阵,则计算复杂度降低到 O ( ∣ E ∣ ) O(|E|) O(∣E∣),其中 E E E表示边。由于SAGPool是一种稀疏池化方法,使用稀疏实现可以降低计算复杂度,而DiffPool是一种密集池化方法,计算复杂度较高。
稀疏性也影响空间复杂性。因为SAGPool使用GNN来获取注意力分数,所以SAGPool需要 O ( ∣ V ∣ + ∣ E ∣ ) O(|V |+|E|) O(∣V∣+∣E∣)的稀疏池化存储空间,而稠密池化方法需要 O ( ∣ V ∣ 2 ) O(|V|^2) O(∣V∣2)。
在DiffPool中,由于GNN产生了assignment矩阵S,因此在构建模型时必须定义cluster的大小。根据参考方法的实现,cluster的大小必须与最大节点数成比例。DiffPool的这些要求会导致两个问题。
在SAGPool中,参数的数量与cluster的大小无关。此外,可以根据输入节点的数量更改cluster大小。
为了研究SAGPool方法的潜力,在两个数据集上评估了SAGPool的变种。可以用以下操作修改SAGPool: