- 题目:Meta-GNN: On Few-shot Node Classification in Graph Meta-learning
- 会议:CIKM (CCF-B)
- 链接:https://dl.acm.org/doi/pdf/10.1145/3357384.3358106
- 源码:https://github.com/AI-DL-Conference/Meta-GNN
- 时间:2019年11月
- 摘要:用GNN来实现小样本分类,运用元学习的策略来训练模型。用到的gnn模型是SGC 和 GCN ,没有对gnn模型任何创新性的改造,但是将gnn上的元学习实验过程描述得很清楚,并且提供了各种baseline的代码,对gnn+小样本的实验有参考价值,总体而言比较基础。
贡献:
定义无向图: G = ( V , E , A , X ) \mathcal{G}=(V, E, \mathbf{A}, \mathbf{X}) G=(V,E,A,X)
问题定义:
GNN结合图结构信息和节点特征X来学习结点的向量表示hv,通常采用邻域聚合的方法。
下图展示了一个结点聚合的例子,在图通过GNN的第一层时,红色节点对节点1、2和3的信息进行聚合,在第二层之后,红色节点对节点5和节点6的信息进行聚合。
因此第l层定义为:
a v ( l ) = h v ( l − 1 ) ⋅ AGGREGATE ( l ) ( { h u ( l − 1 ) : u ∈ N ( v ) } ) h v ( l ) = h v ( l − 1 ) ⋅ COMBINE ( l ) ( a v ( l ) ) \begin{aligned} \mathbf{a}_{v}^{(l)} &=\mathbf{h}_{v}^{(l-1)} \cdot \text { AGGREGATE }^{(l)}\left(\left\{\mathbf{h}_{u}^{(l-1)}: u \in \mathcal{N}(v)\right\}\right) \\ \mathbf{h}_{v}^{(l)} &=\mathbf{h}_{v}^{(l-1)} \cdot \operatorname{COMBINE}^{(l)}\left(\mathbf{a}_{v}^{(l)}\right) \end{aligned} av(l)hv(l)=hv(l−1)⋅ AGGREGATE (l)({hu(l−1):u∈N(v)})=hv(l−1)⋅COMBINE(l)(av(l))
h v ( l ) \mathbf{h}_{v}^{(l)} hv(l) :第l层 结点v的特征向量
初始化 h ( 0 ) = X \mathbf{h}^{(0)}=\mathbf{X} h(0)=X
N ( v ) \mathcal{N}(v) N(v) 表示结点 v 的邻居节点集合
AGGREGATE 和 COMBINE 操作的选择对任务性能至关重要
此处graphSAGE为例子,介绍一下 聚合和组合的思想。
补充学习 graphSAGE
basic GNN: h v k = σ ( W k ∑ u ∈ N ( v ) h u k − 1 ∣ N ( v ) ∣ + B k h v k − 1 ) \mathbf{h}_{v}^{k}=\sigma\left(\mathbf{W}_{k} \sum_{u \in N(v)} \frac{\mathbf{h}_{u}^{k-1}}{|N(v)|}+\mathbf{B}_{k} \mathbf{h}_{v}^{k-1}\right) hvk=σ(Wk∑u∈N(v)∣N(v)∣huk−1+Bkhvk−1)
graphSAGE : h v k = σ ( [ A k ⋅ AGG ( { h u k − 1 , ∀ u ∈ N ( v ) } ) , B k h v k − 1 ] ) \mathbf{h}_{v}^{k}=\sigma\left(\left[\mathbf{A}_{k} \cdot \operatorname{AGG}\left(\left\{\mathbf{h}_{u}^{k-1}, \forall u \in N(v)\right\}\right), \mathbf{B}_{k} \mathbf{h}_{v}^{k-1}\right]\right) hvk=σ([Ak⋅AGG({huk−1,∀u∈N(v)}),Bkhvk−1])
在basic GNN中对邻居的聚合采用的是取平均的操作,对邻居结点和当前结点特征的融合采用的是直接相加的方式。为了进一步提升模型的表达能力,graphSAGE提出了AGGREGATE 和 COMBINE 的方法。
最后论文挑选了SGC 和 GCN 来实现Meta-GNN模型。
步骤:
训练步骤:
对于每一个小任务,将支持集送入Meta-GNN,计算交叉熵损失。
L T i ( f θ ) = − ( ∑ x i s , y i s y i s log f θ ( x i s ) + ( 1 − y i s ) log ( 1 − f θ ( x i s ) ) ) \mathcal{L}_{\mathcal{T}_{i}}\left(f_{\boldsymbol{\theta}}\right)=-\left(\sum_{\boldsymbol{x}_{i s}, y_{i s}} y_{i s} \log f_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{i s}\right)+\left(1-y_{i s}\right) \log \left(1-f_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{i s}\right)\right)\right) LTi(fθ)=−(∑xis,yisyislogfθ(xis)+(1−yis)log(1−fθ(xis)))
然后我们执行参数更新,在任务Ti中使用一个或几个步骤的简单梯度下降(为了简单起见,后续只描述一个梯度更新)
θ i ′ = θ − α 1 ∂ L T i ( f θ ) ∂ θ \theta_{i}^{\prime}=\theta-\alpha_{1} \frac{\partial \mathcal{L}_{\mathcal{T}_{i}}\left(f_{\theta}\right)}{\partial \boldsymbol{\theta}} θi′=θ−α1∂θ∂LTi(fθ)
α1为任务学习率,训练模型参数以优化fθ ’ 在元训练任务中的性能。更具体地说,元目标如下:
θ = arg min θ ∑ T i ∼ p ( T ) L T i ( f θ i ′ ) \boldsymbol{\theta}=\underset{\theta}{\arg \min } \sum_{\mathcal{T}_{i} \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_{i}}\left(f_{\theta_{i}^{\prime}}\right) θ=θargmin∑Ti∼p(T)LTi(fθi′)
也就是说,每个任务都使用在自己的任务上进行一次梯度更新得到 θ ′ \theta_^{′} θi′,最终会得到多个 θ ′ \theta_^{′} θi′,优化的最终目的是找一个最优的,使得每一个 θ ′ \theta_^{′} θi′在自身的任务上损失最小。
对于元测试,我们只需要将新的小样本学习任务支持集的节点输入到元Meta-GNN中,并通过一个或少量梯度下降步骤更新参数θ '。因此,在查询集上可以很容易地评估Meta-GNN的性能。
2️⃣Baseline:
3️⃣实施: