论文:Inductive Representation Learning on Large Graphs 在大图上的归纳表示学习
作者:Hamilton, William L. and Ying, Rex and Leskovec, Jure(斯坦福)
来源:NIPS 2017
论文链接:https://arxiv.org/abs/1706.02216
github链接:https://github.com/williamleif/GraphSAGE
官方介绍链接:http://snap.stanford.edu/graphsage/
此文提出的方法叫GraphSAGE,针对的问题是之前的网络表示学习的transductive,从而提出了一个inductive的GraphSAGE算法。GraphSAGE同时利用节点特征信息和结构信息得到Graph Embedding的映射,相比之前的方法,之前都是保存了映射后的结果,而GraphSAGE保存了生成embedding的映射,可扩展性更强,对于节点分类和链接预测问题的表现也比较突出。
现存的方法需要图中所有的顶点在训练embedding的时候都出现;这些前人的方法本质上是transductive,不能自然地泛化到未见过的顶点。
文中提出了GraphSAGE,是一个inductive的框架,可以利用顶点特征信息(比如文本属性)来高效地为没有见过的顶点生成embedding。
GraphSAGE是为了学习一种节点表示方法,即如何通过从一个顶点的局部邻居采样并聚合顶点特征,而不是为每个顶点训练单独的embedding。
这个算法在三个inductive顶点分类benchmark上超越了那些很强的baseline。文中基于citation和Reddit帖子数据的信息图中对未见过的顶点分类,实验表明使用一个PPI(protein-protein interactions)多图数据集,算法可以泛化到完全未见过的图上。
在大型图中,节点的低维向量embedding被证明了作为各种各样的预测和图分析任务的特征输入是非常有用的。顶点embedding最基本的基本思想是使用降维技术从高维信息中提炼一个顶点的邻居信息,存到低维向量中。这些顶点嵌入之后会作为后续的机器学习系统的输入,解决像顶点分类、聚类、链接预测这样的问题。
GCN虽然能提取图中顶点的embedding,但是存在一些问题:
GCN的基本思想: 把一个节点在图中的高纬度邻接信息降维到一个低维的向量表示。
GCN的优点: 可以捕捉graph的全局信息,从而很好地表示node的特征。
GCN的缺点: Transductive learning的方式,需要把所有节点都参与训练才能得到node embedding,无法快速得到新node的embedding。
前人的工作专注于从一个固定的图中对顶点进行表示,很多现实中的应用需要很快的对未见过的顶点或是全新的图(子图)生成embedding。这种推断的能力对于高吞吐的机器学习系统来说很重要,这些系统都运作在不断演化的图上,而且时刻都会遇到未见过的顶点(比如Reddit上的帖子(posts),Youtube上的用户或视频)。因此,一种inductive的学习方法比transductive的更重要。
transductive learning得到新节点的表示的难处:
要想得到新节点的表示,需要让新的graph或者subgraph去和已经优化好的node embedding去“对齐(align)”。然而每个节点的表示都是受到其他节点的影响,因此添加一个节点,意味着许许多多与之相关的节点的表示都应该调整。这会带来极大的计算开销,即使增加几个节点,也要完全重新训练所有的节点。
GraphSAGE基本思路:
既然新增的节点,一定会改变原有节点的表示,那么为什么一定要得到每个节点的一个固定的表示呢?何不直接学习一种节点的表示方法。去学习一个节点的信息是怎么通过其邻居节点的特征聚合而来的。 学习到了这样的“聚合函数”,而我们本身就已知各个节点的特征和邻居关系,我们就可以很方便地得到一个新节点的表示了。
GCN等transductive的方法,学到的是每个节点的一个唯一确定的embedding; 而GraphSAGE方法学到的node embedding,是根据node的邻居关系的变化而变化的,也就是说,即使是旧的node,如果建立了一些新的link,那么其对应的embedding也会变化,而且也很方便地学到。
本文出自斯坦福PinSAGE的理论篇,关于第一个基于GCN的工业级推荐系统的PinSAGE可以看这篇:
[PinSage] Graph Convolutional Neural Networks for Web-Scale Recommender Systems 论文详解KDD2018
GraphSAGE算法在概念上与以前的节点embedding方法、一般的图形学习监督方法以及最近将卷积神经网络应用于图形结构化数据的进展有关。
一些node embedding方法使用随机游走的统计方法和基于矩阵分解学习目标学习低维的embeddings
这些embedding算法直接训练单个节点的节点embedding,本质上是transductive,而且需要大量的额外训练(如随机梯度下降)使他们能预测新的顶点。
此外,Yang et al.的Planetoid-I算法,是一个inductive的基于embedding的半监督学习算法。然而,Planetoid-I在推断的时候不使用任何图结构信息,而在训练的时候将图结构作为一种正则化的形式。
不像前面的这些方法,本文利用特征信息来训练可以对未见过的顶点生成embedding的模型。
Graph kernel
除了节点嵌入方法,还有大量关于图结构数据的监督学习的文献。这包括各种各样的基于内核的方法,其中图的特征向量来自不同的图内核(参见Weisfeiler-lehman graph kernels和其中的引用)。
一些神经网络方法用于图结构上的监督学习,本文的方法在概念上受到了这些算法的启发
然而,这些以前的方法是尝试对整个图(或子图)进行分类的,但是本文的工作的重点是为单个节点生成有用的表示。
近年来,提出了几种用于图上学习的卷积神经网络结构
这些方法中的大多数不能扩展到大型图,或者设计用于全图分类(或者两者都是)。
GraphSAGE可以看作是对transductive的GCN框架对inductive下的扩展。
GraphSAGE的核心:GraphSAGE不是试图学习一个图上所有node的embedding,而是学习一个为每个node产生embedding的映射。
文中不是对每个顶点都训练一个单独的embeddding向量,而是训练了一组aggregator functions,这些函数学习如何从一个顶点的局部邻居聚合特征信息(见图1)。每个聚合函数从一个顶点的不同的hops或者说不同的搜索深度聚合信息。测试或是推断的时候,使用训练好的系统,通过学习到的聚合函数来对完全未见过的顶点生成embedding。
GraphSAGE 是Graph SAmple and aggreGatE的缩写,其运行流程如上图所示,可以分为三个步骤:
文中设计了无监督的损失函数,使得GraphSAGE可以在没有任务监督的情况下训练。实验中也展示了使用完全监督的方法如何训练GraphSAGE。
GraphSAGE的前向传播算法如下,前向传播描述了如何使用聚合函数对节点的邻居信息进行聚合,从而生成节点embedding:
在每次迭代(或搜索深度),顶点从它们的局部邻居聚合信息,并且随着这个过程的迭代,顶点会从越来越远的地方获得信息。PinSAGE使用的前向传播算法和GraphSAGE一样,GraphSAGE是PinSAGE的理论基础。
算法1描述了在整个图上生成embedding的过程,其中
为了将算法1扩展到minibatch环境上,给定一组输入顶点,先采样采出需要的邻居集合(直到深度 K K K),然后运行内部循环(算法1的第三行)(附录A包括了完整的minibatch伪代码)。
GraphSAGE算法从概念上受到测试图同构的经典算法的启发。对于算法1,如果
(1) K = ∣ V ∣ K=|\mathcal{V}| K=∣V∣
(2)设置权重矩阵为单位矩阵
(3)使用一个适当的哈希函数作为一个聚合器(没有非线性)
那么算法1是Weisfeiler-Lehman的实例(WL)同构测试,也称为“naive vertex refinement”[Weisfeiler-lehman graph kernels]。
如果对两个子图通过算法1的特征表示的集合{ z v , ∀ v ∈ V z_v,\forall v \in \mathcal{V} zv,∀v∈V}输出是相同的,那么ML测试就认为这两个子图是同构的。如果用可训练的神经网络聚合器替换了哈希函数,那么GraphSAGE就是WL测试的一个连续近似。当然,文中使用GraphSAGE是为了生成有用的节点表示—而不是测试图的同构。然而,GraphSAGE与经典的WL检验之间的联系为算法设计学习节点邻域的拓扑结构提供了理论基础。
出于对计算效率的考虑,对每个顶点采样一定数量的邻居顶点作为待聚合信息的顶点。设需要的邻居数量,即采样数量为 S S S,若顶点邻居数少于 S S S,则采用有放回的抽样方法,直到采样出 S S S个顶点。若顶点邻居数大于 S S S,则采用无放回的抽样。(即采用有放回的重采样/负采样方法达到 S S S)
当然,若不考虑计算效率,完全可以对每个顶点利用其所有的邻居顶点进行信息聚合,这样是信息无损的。
文中在较大的数据集上实验。因此,统一采样一个固定大小的邻域集,以保持每个batch的计算占用空间是固定的(即 graphSAGE并不是使用全部的相邻节点,而是做了固定size的采样)。
这样固定size的采样,每个节点和采样后的邻居的个数都相同,可以把每个节点和它们的邻居拼成一个batch送到GPU中进行批训练。
在图中顶点的邻居是无序的,所以希望构造出的聚合函数是对称的(即也就是对它输入的各种排列,函数的输出结果不变),同时具有较高的表达能力。 聚合函数的对称性(symmetry property)确保了神经网络模型可以被训练且可以应用于任意顺序的顶点邻居特征集合上。
mean aggregator将目标顶点和邻居顶点的第 k − 1 k-1 k−1层向量拼接起来,然后对向量的每个维度进行求均值的操作,将得到的结果做一次非线性变换产生目标顶点的第 k k k层表示向量。
文中用下面的式子替换算法1中的4行和5行得到GCN的inductive变形:
h v k ← σ ( W ⋅ M E A N ( { h v k − 1 } ∪ { h u k − 1 , ∀ u ∈ N ( v ) } ) ) ( 2 ) \mathbf{h}^k_v \leftarrow \sigma(\mathbf{W} \cdot \mathrm{MEAN}(\lbrace \mathbf{h}^{k-1}_v \rbrace \cup \lbrace \mathbf{h}^{k-1}_u, \forall u \in \mathcal{N}(v) \rbrace)) \qquad(2) hvk←σ(W⋅MEAN({hvk−1}∪{huk−1,∀u∈N(v)}))(2)
原始第4,5行是
h v k ← σ ( W ⋅ MEAN ( { h v k − 1 } ∪ { h u k − 1 , ∀ u ∈ N ( v ) } ) h v k ← σ ( W k ⋅ CONCAT ( h v k − 1 , h N ( v ) k ) ) \mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W} \cdot \operatorname{MEAN}\left(\left\{\mathbf{h}_{v}^{k-1}\right\} \cup\left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}\right)\right . \\ \mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W}^{k} \cdot \operatorname{CONCAT}\left(\mathbf{h}_{v}^{k-1}, \mathbf{h}_{\mathcal{N}(v)}^{k}\right)\right) hvk←σ(W⋅MEAN({hvk−1}∪{huk−1,∀u∈N(v)})hvk←σ(Wk⋅CONCAT(hvk−1,hN(v)k))
文中也测试了一个基于LSTM的复杂的聚合器[Long short-term memory]。和均值聚合器相比,LSTMs有更强的表达能力。但是,LSTMs不是symmetric的,也就是说不具有排列不变性(permutation invariant),因为它们以一个序列的方式处理输入。因此,需要先对邻居节点随机顺序,然后将邻居序列的embedding作为LSTM的输入。
pooling聚合器,它既是对称的,又是可训练的。Pooling aggregator 先对目标顶点的邻居顶点的embedding向量进行一次非线性变换,之后进行一次pooling操作(max pooling or mean pooling),将得到结果与目标顶点的表示向量拼接,最后再经过一次非线性变换得到目标顶点的第k层表示向量。
一个element-wise max pooling操作应用在邻居集合上来聚合信息:
h N ( v ) k = A G G R E G A T E k p o o l = m a x ( { σ ( W p o o l h u k − 1 + b ) , ∀ u ∈ N ( v ) } ) ( 3 ) h v k ← σ ( W k ⋅ CONCAT ( h v k − 1 , h N ( v ) k ) ) \mathbf{h}_{\mathcal{N}(v)}^{k}=\mathrm{AGGREGATE}^{pool}_k = \mathrm{max}(\lbrace \sigma (\mathbf{W}_{pool} \mathbf{h}^{k-1}_{u} + \mathbf{b}), \forall u \in \mathcal{N}(v) \rbrace) \qquad(3) \\ \mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W}^{k} \cdot \operatorname{CONCAT}\left(\mathbf{h}_{v}^{k-1}, \mathbf{h}_{\mathcal{N}(v)}^{k}\right)\right) hN(v)k=AGGREGATEkpool=max({σ(Wpoolhuk−1+b),∀u∈N(v)})(3)hvk←σ(Wk⋅CONCAT(hvk−1,hN(v)k))
其中
在定义好聚合函数之后,接下来就是对函数中的参数进行学习。文章分别介绍了无监督学习和监督学习两种方式。
基于图的损失函数倾向于使得相邻的顶点有相似的表示,但这会使相互远离的顶点的表示差异变大:
J G ( z u ) = − log ( σ ( z u T z v ) ) − Q ⋅ E v n ∼ P n ( v ) log ( σ ( − z u T z v n ) ) ( 1 ) J\mathcal{G}(\mathbf{z}_u) = -\log(\sigma(\mathbf{z}^T_u \mathbf{z}_v)) - Q \cdot \mathbb{E}_{v_n \sim P_n(v)} \log(\sigma(-\mathbf{z}^T_u \mathbf{z}_{v_n})) \qquad(1) JG(zu)=−log(σ(zuTzv))−Q⋅Evn∼Pn(v)log(σ(−zuTzvn))(1)
其中
文中输入到损失函数的表示 z u \mathbf{z}_u zu是从包含一个顶点局部邻居的特征生成出来的,而不像之前的那些方法(如DeepWalk),对每个顶点训练一个独一无二的embedding,然后简单进行一个embedding查找操作得到。
无监督损失函数的设定来学习节点embedding 可以供下游多个任务使用。监督学习形式根据任务的不同直接设置目标函数即可,如最常用的节点分类任务使用交叉熵损失函数。
通过前向传播得到节点 u u u的embedding z u z_u zu,然后梯度下降(实现使用Adam优化器) 进行反向传播优化参数 W k \mathbf{W}^{k} Wk和聚合函数内的参数。
这个 W k \mathbf{W}^{k} Wk就是所谓的dynamic embedding的核心,因为保存下来了从节点原始的高维特征生成低维embedding的方式。现在,如果想得到一个点的embedding,只需要输入节点的特征向量,经过卷积(利用已经训练好的 W k \mathbf{W}^{k} Wk以及特定聚合函数聚合neighbor的属性信息),就产生了节点的embedding。
除此以外,还比较了GraphSAGE四个变种 ,并无监督生成embedding输入给LR和端到端有监督。因为,GraphSAGE的卷积变体是一种扩展形式,是Kipf et al. 半监督GCN的inductive版本,称这个变体为GraphSAGE-GCN。
分类器均采用LR
在所有这些实验中,预测在训练期间看不到的节点,在PPI数据集的情况下,实验在完全看不见的图上进行了测试。
前两个实验是在演化的信息图中对节点进行分类,这是一个与高吞吐量生产系统特别相关的任务,该系统经常遇到不可见的数据。
第一个任务是在一个大的引文数据集中预测论文主题类别。文中使用来自汤姆森路透科学核心数据库(Thomson Reuters Web of Science Core
Collection)的无向的引文图数据集(对应于2000-2005年六个生物相关领域的所有论文)。这个数据集的节点标签对应于六个不同的领域的标签。该数据集共包含302,424个节点,平均度数为9.15。文中使用2000-2004年的数据集对所有算法进行训练,并使用2005年的数据进行测试(30%用于验证)。对于特征,文中使用节点的度。此外,按照Arora等人的sentence embedding方法处理论文摘要(使用GenSim word2vec实现训练的300维单词向量)。
第二个任务预测不同的Reddit帖子(posts)属于哪个社区。Reddit是一个大型的在线论坛,用户可以在这里对不同主题社区的内容进行发布和评论。作者在Reddit上对2014年9月发布的帖子构建了一个图形数据集。本例中的节点标签是帖子所属的社区或“subreddit”。文中对50个大型社区进行了抽样,并构建了一个帖子-帖子的图,如果同一个用户评论了两个帖子,就将这两个帖子连接起来。该数据集共包含232,965个帖子,平均度为492。文中将前20天的用于训练,其余的用于测试(30%用于验证)。对于特征,文中使用现成的300维GloVe CommonCrawl词向量对于每一篇帖子,将下面的内容连接起来:
考虑跨图进行泛化的任务,这需要了解节点的角色,而不是社区结构。文中在各种蛋白质-蛋白质相互作用(PPI)图中对蛋白质角色进行分类,每个图对应一个不同的人体组织。并且使用从Molecular
Signatures Database中收集的位置基因集、motif基因集和免疫学signatures作为特征,gene ontology作为标签(共121个)。图中平均包含2373个节点,平均度为28.8。文中将所有算法在20个图上训练,然后在两个测试图上预测F1 socres(另外两个图用于验证)。
GraphSAGE的核心:GraphSAGE不是试图学习一个图上所有node的embedding,而是学习一个为每个node产生embedding的映射
改进方向:扩展GraphSAGE以合并有向图或者多模式图;探索非均匀邻居采样函数
为什么GCN是transductive,为啥要把所有节点放在一起训练?
不一定要把所有节点放在一起训练,一个个节点放进去训练也是可以的。无非是如果想得到所有节点的embedding,那么GCN可以把整个graph丢进去,直接得到embedding,还可以直接进行节点分类、边的预测等任务。
其实,通过GraphSAGE得到的节点的embedding,在增加了新的节点之后,旧的节点也需要更新,这个是无法避免的,因为,新增加点意味着环境变了,那之前的节点的表示自然也应该有所调整。只不过,对于老节点,可能新增一个节点对其影响微乎其微,所以可以暂且使用原来的embedding,但如果新增了很多,极大地改变的原有的graph结构,那么就只能全部更新一次了。从这个角度去想的话,似乎GraphSAGE也不是什么“神仙”方法,只不过生成新节点embedding的过程,实施起来相比于GCN更加灵活方便了。在学习到了各种的聚合函数之后,其实就不用去计算所有节点的embedding,而是需要去考察哪些节点,就现场去计算,这种方法的迁移能力也很强,在一个graph上学得了节点的聚合方法,到另一个新的类似的graph上就可以直接使用了。
参考:GraphSAGE NIPS 2017 代码分析(Tensorflow版)
论文:Inductive Representation Learning on Large Graphs
知乎:【Graph Neural Network】GraphSAGE: 算法原理,实现和应用
知乎:网络表示学习: 淘宝推荐系统&&GraphSAGE
GNN 系列(三):GraphSAGE
GraphSAGE: GCN落地必读论文