课程链接:CS224W: Machine Learning with Graphs
课程视频:【课程】斯坦福 CS224W: 图机器学习 (2019 秋 | 英字)
我们上节课讨论了对网络进行编码的架构。同时,我们也介绍了GCN的核心思想——Aggregate neighbours,并讨论了如何使用神经网络去实现。也就说,我们之前讨论的,是如何对网络进行embedding;我们今天要讲的是,如何从embedding生成网络。
网络的生成有很多实际的应用:
问题的引入——有真实的网络 G G G,和人造的网络 G ′ G' G′,那么:
网络的生成主要涉及两个任务:
可以说,网络生成是有趣的,也是很难的任务,它的难主要体现在以下几个方面:
假设我们要通过一组节点数据 { x i } \{x_i\} { xi}来学习网络的生成模型。
p d a t a ( x ) p_{data}(x) pdata(x)是数据的分布(data distribution),实际上我们并不可能知道这个分布,但是我们可以通过对 x i x_i xi的采样(sampling)来得到这个分布,即 x i ∽ p d a t a ( x ) x_i \backsim p_{data}(x) xi∽pdata(x)。
p m o d e l ( x ; θ ) p_{model}(x;\theta) pmodel(x;θ)是模型(model),参数 θ \theta θ用来估计 p d a t a ( x ) p_{data}(x) pdata(x)。
那么,我们的目标就是:
(1)让 p m o d e l ( x ; θ ) p_{model}(x;\theta) pmodel(x;θ)接近于 p d a t a ( x ) p_{data}(x) pdata(x);
核心理论——极大似然估计
(2)确保我们可以从 p m o d e l ( x ; θ ) p_{model}(x;\theta) pmodel(x;θ)采样,并生成网络。
这里的函数 f ( ⋅ ) f(·) f(⋅)采用深度神经网络实现。
网络的生成是通过不断地增加节点和边来实现的。
对应一个确定的节点顺序 π \pi π,图 G G G可以表示为节点和边的序列 S π S^{\pi} Sπ:
S π S^{\pi} Sπ实际上是序列的序列。对应的每一个序列 S i π S_i^{\pi} Siπ,都有两个层次的操作:
A graph + a node ordering = A sequence of sequences!
而解决序列问题,我们自然而然地就能想到利用RNN来实现。
GraphRNN包括两个部分:
那么,我们怎样利用RNN来生成序列呢?
对于一个RNN单元来说,有状态 s t s_t st,输入 x t x_t xt,输出 y t y_t yt。
对于序列的表示,可以将RNN单元重复连接。开始和结束都定义一个标识符,开始的标识符 s 0 = S O S s_0=SOS s0=SOS,结束的标识符 y T = E O S y_T=EOS yT=EOS;上一个状态的输出是下一个状态的输入,即 x t + 1 = y t x_{t+1}=y_t xt+1=yt。
在上述模型的基础上,我们需要给RNN模型增加随机性。首先,我们要明确的是,我们的目标是使用RNN来估计 ∏ k = 1 n p m o d e l ( x t ∣ x 1 , ⋯ , x t − 1 ; θ ) \prod_{k=1}^n p_{model}(x_t|x_1, \cdots, x_{t-1}; \theta) ∏k=1npmodel(xt∣x1,⋯,xt−1;θ)。那么, x t + 1 x_{t+1} xt+1是 y t : x t + 1 ∽ y t y_t:x_{t+1} \backsim y_t yt:xt+1∽yt的取样。
RNN每一步的输出是一个概率向量,下一个状态的输入时基于该概率向量的一个取样。
模型的测试
假设我们有一个已经训练好的模型, y y y服从伯努利分布, y 1 = 0.9 y_1=0.9 y1=0.9表示有0.9的概率生成1,即有边连接;有 1 − 0.9 = 0.1 1-0.9=0.1 1−0.9=0.1的概率生成0,即没有边连接。
模型训练
在进行模型训练的时候,有一个原则——Teacher Forcing,也就是加入我们检测到真实的边的序列为 [1,0,…],在训练时我们用真实的这个序列作为输出。
损失函数定义为Binary cross entropy,训练目标是使损失函数最小化:
L = − [ y 1 ∗ log ( y 1 ) + ( 1 − y 1 ∗ ) log ( 1 − y 1 ) ] L=-[y_1^* \log(y_1)+(1-y_1^*)\log (1-y_1)] L=−[y1∗log(y1)+(1−y1∗)log(1−y1)]
y 1 ∗ y_1^* y1∗是真实结果。如果 y 1 ∗ = 1 y_1^*=1 y1∗=1,则 L = − log ( y 1 ) L=-\log(y_1) L=−log(y1), y 1 y_1 y1越大(越接近 y 1 ∗ y_1^* y1∗), L L L越小;如果 y 1 ∗ = 0 y_1^*=0 y1∗=0,则 L = − log ( 1 − y 1 ) L=-\log (1-y_1) L=−log(1−y1), y 1 y_1 y1越小(越接近 y 1 ∗ y_1^* y1∗), L L L越小。这样,就可以使预测值 y 1 y_1 y1越来越接近实际值 y 1 ∗ y_1^* y1∗。预测值 y 1 y_1 y1通过RNN计算得到,可以通过反向传播不断优化RNN的参数。
然而,我们还是面临一个问题,因为每一个新生成的点都有可能和之前的点进行关联,也就是说,当图的节点数量很大时,我需要记住很长的依赖关系来实现边的生成。