论文:Self-Attention Graph Pooling
作者:Junhyun Lee, Inyeop Lee, Jaewoo Kang
韩国首尔高丽大学计算机科学与工程系
来源:ICML 2019
论文链接:
Arxiv: https://arxiv.org/abs/1904.08082
github链接:https://github.com/inyeoplee77/SAGPool
近年来,人们提出了将深度学习应用于图数据等结构化数据的方法。研究工作集中在将CNN泛化到图数据,重新定义图数据的卷积操作和downsampling(池化)操作。将卷积运算推广到图上的方法已被证明了可以提高性能,并得到了广泛的应用。然而,将downsampling应用于图数据的方法仍然难以实现,还有改进的空间。此文提出了一种基于self-attention的图池化方法。使用图卷积的self-attention使得化池化方法同时考虑了节点特征和图的拓扑结构。实验中,为了确保公平的比较,将现有的池化方法和文中提出的SAGPool方法使用了相同的训练流程和模型架构。实验结果表明,该方法使用合理数量的参数,在基准数据集上获得了较好的图分类性能。
目前,图池化的方法比图卷积的方法要少,而现存的基于池化的方法存在一些问题:
文中提出了SAGPool,这是一种基于层次图池化的Self-Attention Graph方法。
简而言之,SAGPool具有前几种方法的优点:分层池化,同时考虑节点特征和图的拓扑结构,合理的复杂度,以及端到端表示学习。SAGPool是第一个使用self-attention进行图池化处理并实现高性能的方法。SAGPool参数量一致,不用考虑输入图的大小。
GNNs因其在图领域中的表现而备受关注。CNN模型中的池化层通过缩小图片的表示的大小来减少参数的数量,从而避免过拟合问题。为了将CNNs推广到图数据上,GNNs使用池化方法是必要的。图数据池化方法可以分为以下三类:基于拓扑的池化、全局池化和分层池化。
早期的工作使用的是图的粗化算法,而不是使用神经网络。谱聚类算法利用特征分解得到粗化图。然而,由于特征分解的时间复杂度过大问题,需要一些替代方法:
(1) Weighted graph cuts without eigenvectors a multilevel approach, 2007
Graclus计算了给定图无特征向量的的聚类,因为一般谱聚类目标和加权核k-means目标之间存在数学等价性。
(2)在最近的GNN模型中Graclus被用作池化模块:
与基于拓扑的池化方法不同,全局池化方法考虑了图的特征。全局池化方法使用求和或神经网络pool每个层中所有节点的表示。全局池化方法pool所有的表示,可以处理具有不同结构的图:
(1)Neural message passing for quantum chemistry,2017)
将GNNs视为消息传递方案,提出了一种图分类的通用框架,利用Set2Set方法可以获得整个图的表示。
(2)An end-to-end deep learning architecture for graph classification,AAAI 2018
SortPool:根据图的结构对节点的embeddings进行排序,并将排序后的embeddings传递给下一层。
全局池化方法没有学习对捕获图结构信息至关重要的层次表示。分层池化方法的主要动机是建立一个能够学习每一层中基于特征或拓扑的节点assignment的模型:
(1)[DIFFPOOL] Hierarchical Graph Representation Learning with Differentiable Pooling,NeurIPS 2018
DIFFPOOL是一种可微的图池化方法,能够以端到端的方式学习assignment矩阵: S ( l ) ∈ R n l × n l + 1 S^{(l)} \in \mathbb{R}^{n_{l} \times n_{l+1}} S(l)∈Rnl×nl+1, n l n_l nl表示在第 l l l层的节点数, n l + 1 n_{l+1} nl+1表示在第 l + 1 l+1 l+1层的节点数, n l < n l + 1 n_l < n_{l+1} nl<nl+1,即行数等于第 l l l层的节点数(cluster数),列数代表第 l + 1 l+1 l+1层的节点数(cluster数)。assignment matrix表示第 l l l层的每一个节点到第 l + 1 l+1 l+1层的每一个节点(或cluster)的概率。
具体而言,节点根据下面的公式分到下一层的cluster:
S ( l ) = softmax ( G N N l ( A ( l ) , X ( l ) ) ) S^{(l)}=\operatorname{softmax}\left(\mathrm{GNN}_{l}\left(A^{(l)}, X^{(l)}\right)\right) S(l)=softmax(GNNl(A(l),X(l)))
A ( l + 1 ) = S ( l ) ⊤ A ( l ) S ( l ) (1) \tag{1} A^{(l+1)}=S^{(l) \top} A^{(l)} S^{(l)} A(l+1)=S(l)⊤A(l)S(l)(1)
具体细节,可以参考另一篇博文:[DIFFPOOL 图分类] - Hierarchical Graph Representation Learning with Differentiable Pooling NeurIPS 2018
(2)Graph u-net,ICML 2019
gPool实现了与DiffPool相当的性能。gPool需要 O ( ∣ V ∣ + ∣ E ∣ ) O(|V| + |E|) O(∣V∣+∣E∣)的空间复杂度,而DiffPool需要 O ( k ∣ V ∣ 2 ) O(k|V|^2) O(k∣V∣2),其中 V , E , k V,E,k V,E,k分别表示顶点数、边数和池化比率。gPool使用一个可学习的向量 p p p来计算投影分数,然后使用这些分数来从图中选择排名最高的节点保留下来。投影分数由 p p p与各节点特征的点积得到。分数表示可以保留的节点信息量。下式大致描述了gPool中的池化过程:
y = X ( l ) p ( l ) / ∥ p ( l ) ∥ , id x = top − rank ( y , ⌈ k N ⌉ ) y=X^{(l)} \mathbf{p}^{(l)} /\left\|\mathbf{p}^{(l)}\right\|, \quad \text { id } \mathbf{x}=\operatorname{top}-\operatorname{rank}(y,\lceil k N\rceil) y=X(l)p(l)/∥∥∥p(l)∥∥∥, id x=top−rank(y,⌈kN⌉)
A ( l + 1 ) = A i d x , i d x ( l ) (2) \tag{2} A^{(l+1)}=A_{\mathrm{idx}, \mathrm{idx}}^{(l)} A(l+1)=Aidx,idx(l)(2)
在公式(2)中,投影分数的计算没有考虑图的拓扑结构。
为了进一步改进图池化方法,文中提出了SAGPool,它可以使用图的特征和拓扑结构信息来产生具有合理的时间和空间复杂度的层次表示。
SAGPool的关键在于它使用GNN来提供self-attention分数。
注意力机制在最近的深度学习研究中被广泛应用。这种机制可以使我们能够关注更重要的特征。self-attention,通常被称为intra-attention,关注的特征是注意力本身。SAGPool利用图卷积的方法得到self-attention分数。例如,如果使用Kipf的图卷积公式,则self-attention分数 Z ∈ R N × 1 Z \in \mathbb{R}^{N \times 1} Z∈RN×1根据如下计算:
Z = σ ( D ~ − 1 2 A ~ D ~ − 1 2 X Θ a t t ) (3) \tag{3} Z=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} X \Theta_{a t t}\right) Z=σ(D~−21A~D~−21XΘatt)(3)
因为利用图卷积得到self-attention分数,所以这种池化的结果是基于图的特征和拓扑的。
SAGPool采用了gPool中的节点选择方法,保留了输入图的一部分节点:
i d x = top-rank ( Z , ⌈ k N ⌉ ) , Z m a s k = Z i d x (4) \tag{4} \mathrm{idx}=\operatorname{top-rank}(Z,\lceil k N\rceil), \quad Z_{m a s k}=Z_{\mathrm{idx}} idx=top-rank(Z,⌈kN⌉),Zmask=Zidx(4)
输入图由图1中标记为masking的操作处理。
X ′ = X i d x , : , X o u t = X ′ ⊙ Z m a s k , A o u t = A i d x , i d x (5) \tag{5} X^{\prime}=X_{\mathrm{idx}, \mathrm{:}} \quad, \quad X_{o u t}=X^{\prime} \odot Z_{m a s k} \quad , \quad A_{o u t}=A_{\mathrm{idx}, \mathrm{idx}} X′=Xidx,:,Xout=X′⊙Zmask,Aout=Aidx,idx(5)
SAGPool中使用图卷积的主要原因是为了获得拓扑结构和节点特征。如果GNNs以节点特征和邻接矩阵为输入,则可以用式(3)Kipf的图卷积公式代替GNNs的各种公式。计算注意力分数 Z ∈ R N × 1 Z \in \mathbb{R}^{N \times 1} Z∈RN×1的广义公式如下:
Z = σ ( GNN ( X , A ) ) (6) \tag{6} Z=\sigma(\operatorname{GNN}(X, A)) Z=σ(GNN(X,A))(6)
计算注意力分数,不仅可以使用相邻节点,也可以使用多跳连接的节点。可以使用添加改变邻接矩阵形式扩展边,堆叠多层GNN层,使用多个注意力分数的平均值等方法来达到这个目的。
以一个连接两跳的节点为例。
(1)添加邻接矩阵的平方: SAGPool augmentation \text { SAGPool}_{\text {augmentation}} SAGPoolaugmentation
式(7)使用了两跳连接,该连接涉及边的扩展,允许两跳节点的间接聚合。添加邻接矩阵的平方相当于在两跳邻居之间创建了边:
Z = σ ( GNN ( X , A + A 2 ) ) (7) \tag{7} Z=\sigma\left(\operatorname{GNN}\left(X, A+A^{2}\right)\right) Z=σ(GNN(X,A+A2))(7)
(2)叠加两层GNN层: SAGPool serial \text { SAGPool}_{\text {serial}} SAGPoolserial
式(8)使用了两跳连接,该连接涉及GNN层的堆叠,允许两跳节点的间接聚合。在这种情况下,SAGPool层的非线性和参数数量将增加:
Z = σ ( GNN 2 ( σ ( GNN 1 ( X , A ) ) , A ) ) (8) \tag{8} Z=\sigma\left(\operatorname{GNN}_{2}\left(\sigma\left(\operatorname{GNN}_{1}(X, A)\right), A\right)\right) Z=σ(GNN2(σ(GNN1(X,A)),A))(8)
公式(7)和公式(8)可以应用到更多跳的连接上。
(3)取多个注意力分数的平均值: SAGPool parallel \text { SAGPool}_{\text {parallel}} SAGPoolparallel
另一种方法取多个注意力得分的平均值。 M M M个GNNs平均注意力分值如下:
Z = 1 M ∑ m σ ( G N N m ( X , A ) ) (9) \tag{9} Z=\frac{1}{M} \sum_{m} \sigma\left(\mathrm{GNN}_{m}(X, A)\right) Z=M1m∑σ(GNNm(X,A))(9)
文中分别将公式(7),(8),(9)中的模型称为 SAGPool augmentation \text { SAGPool}_{\text {augmentation}} SAGPoolaugmentation, SAGPool serial \text { SAGPool}_{\text {serial}} SAGPoolserial和 SAGPool parallel \text { SAGPool}_{\text {parallel}} SAGPoolparallel。
根据(Troubling trends in machine learning scholarship,2018)的观点,如果对一个模型进行了多次修改,那么就很难确定哪些修改有助于提高性能。为了公平的比较,文中采用了SortPool(AAAI 2018)和(Towards sparse hierarchical graph classifiers,2018)中的模型架构,并使用相同的架构来比较baseline和文中的方法。
使用Kipf的GCN:
h ( l + 1 ) = σ ( D ~ − 1 2 A ~ D ~ − 1 2 h ( l ) Θ ) (10) \tag{10} h^{(l+1)}=\sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} h^{(l)} \Theta\right) h(l+1)=σ(D~−21A~D~−21h(l)Θ)(10)
受JK-net架构(Representation learning on graphs with jumping knowledge networks,2018;Towards sparse hierarchical graph classifiers,2018)的启发,提出了一种readout层,该层聚合节点特征以形成固定大小的表示。readout层的输出特征如下:
s = 1 N ∑ i = 1 N x i ∥ max i = 1 N x i (11) \tag{11} s=\frac{1}{N} \sum_{i=1}^{N} x_{i} \| \max _{i=1}^{N} x_{i} s=N1i=1∑Nxi∥i=1maxNxi(11)
实现了(An end-to-end deep learning architecture for graph classification,AAAI 2018)中提出的全局池化架构。
实现了(Towards sparse hierarchical graph classifiers,2018)分层池化架构。
在图分类任务中,评估了全局池化和分层池化方法。
选取了5个图的数量较大的数据集( > 1 k > 1k >1k):
DiffPool,gPool和 SAGPool h \text { SAGPool }_{h} SAGPool h使用分层池化架构;
Set2Set, SortPool和 SAGPool g \text { SAGPool }_{g} SAGPool g使用全局池化架构;
对所有的baselines和SAGPool使用相同的超参数搜索策略。超参数如表2所示。
Z = 1 M ∑ m σ ( G N N m ( X , A ) ) (9) \tag{9} Z=\frac{1}{M} \sum_{m} \sigma\left(\mathrm{GNN}_{m}(X, A)\right) Z=M1m∑σ(GNNm(X,A))(9)
在DiffPool中,由于GNN产生了如式(1)所示的assignment矩阵S,因此在构建模型时必须定义cluster的大小。根据参考方法的实现,cluster的大小必须与最大节点数成比例。DiffPool的这些要求会导致两个问题。
在SAGPool中,参数的数量与cluster的大小无关。此外,可以根据输入节点的数量更改cluster大小。
为了研究SAGPool方法的潜力,在两个数据集上评估了SAGPool的变种。可以用以下操作修改SAGPool: