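This post gives a brief introduction to implementing a GCN with DGL.
Paper link
We walk through what sits underneath the GraphConv module, so that readers can see how to define a new GNN layer with DGL's APIs.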
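We describe a graph convolutional layer from a message-passing perspective; the precise mathematical formulation is given further below. For each node u, it boils down to the following steps:
1. Aggregate the representations $h_v$ of the neighbors of u into an intermediate representation $\hat{h}_u$.
2. Transform the aggregated representation $\hat{h}_u$ with a linear projection followed by a non-linearity: $h_u = f(W \hat{h}_u)$.
We implement step 1 with DGL message passing, and step 2 with a PyTorch nn.Module.
We first define the message and reduce functions as usual. Because the aggregation at a node u only involves summing the neighbors' representations $h_v$, we can simply use built-in functions: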
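import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

# Message: copy the source node's feature 'h' into the message field 'm'.
# (Newer DGL releases name this built-in fn.copy_u(u='h', out='m').)
gcn_msg = fn.copy_src(src='h', out='m')
# Reduce: sum all incoming messages 'm' and store the result back in 'h'.
gcn_reduce = fn.sum(msg='m', out='h')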
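To make the two built-ins concrete, here is a dependency-free sketch of what "copy the source feature along each edge, then sum at the destination" computes. The toy edge list, the feature values, and the helper name `sum_aggregate` are made up for illustration; this is not how DGL implements it internally.

```python
def sum_aggregate(num_nodes, edges, feats):
    """For every node, sum the feature vectors of its in-neighbors.

    edges is a list of (src, dst) pairs; feats[i] is the feature list of node i.
    """
    dim = len(feats[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for src, dst in edges:
        # "message": copy the source node's feature onto the edge
        msg = feats[src]
        # "reduce": add the message into the destination's accumulator
        for k in range(dim):
            out[dst][k] += msg[k]
    return out


# Hypothetical toy graph with edges 0 -> 1, 0 -> 2, 1 -> 2
edges = [(0, 1), (0, 2), (1, 2)]
feats = [[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]]
agg = sum_aggregate(3, edges, feats)
print(agg)  # [[0.0, 0.0], [1.0, 2.0], [11.0, 22.0]]
```

Node 0 has no incoming edges, so its aggregated vector is all zeros and its own previous feature is not carried over. This is one reason GCN implementations usually add a self-loop to every node.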
Next we define the GCNLayer module. A GCNLayer performs message passing on all the nodes and then applies a fully connected layer.
class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, feature):
        # Creating a local scope so that all the stored ndata and edata
        # (such as the `'h'` ndata below) are automatically popped out
        # when the scope exits.
        with g.local_scope():
            g.ndata['h'] = feature
            g.update_all(gcn_msg, gcn_reduce)
            h = g.ndata['h']
            return self.linear(h)
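The forward function is essentially the same as in any other commonly seen neural network model in PyTorch, and a GCN can be initialized like any other nn.Module. For example, let us define a simple neural network consisting of two GCN layers, assuming we will train it on the cora dataset (input feature size 1433, 7 classes). The last GCN layer computes the node embeddings, so the final layer in general does not apply an activation.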
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = GCNLayer(1433, 16)
        self.layer2 = GCNLayer(16, 7)

    def forward(self, g, features):
        x = F.relu(self.layer1(g, features))
        x = self.layer2(g, x)
        return x

net = Net()
print(net)
Out:
Net(
  (layer1): GCNLayer(
    (linear): Linear(in_features=1433, out_features=16, bias=True)
  )
  (layer2): GCNLayer(
    (linear): Linear(in_features=16, out_features=7, bias=True)
  )
)
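As a quick sanity check on the printed structure, the number of trainable parameters follows directly from the layer sizes: each nn.Linear holds a weight matrix with out_features × in_features entries plus a bias of length out_features. The helper name `linear_param_count` below is hypothetical and only does the arithmetic.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Number of scalars in a fully connected layer's weight and bias."""
    return in_features * out_features + (out_features if bias else 0)


layer1 = linear_param_count(1433, 16)  # first GCNLayer's linear
layer2 = linear_param_count(16, 7)     # second GCNLayer's linear
total = layer1 + layer2
print(layer1, layer2, total)  # 22944 119 23063
```

Almost all of the parameters sit in the first layer, because the 1433-dimensional input features dominate the sizes.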
We load the cora dataset using DGL's built-in data module.
from dgl.data import citation_graph as citegrh
import networkx as nx

def load_cora_data():
    data = citegrh.load_cora()
    features = th.FloatTensor(data.features)
    labels = th.LongTensor(data.labels)
    train_mask = th.BoolTensor(data.train_mask)
    test_mask = th.BoolTensor(data.test_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, train_mask, test_mask
Once the model is trained, we can evaluate its performance on the test set as follows:
def evaluate(model, g, features, labels, mask):
    model.eval()
    with th.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = th.max(logits, dim=1)
        correct = th.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)
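The evaluation above is an argmax over the class scores, restricted to the nodes selected by the mask. The following plain-Python sketch shows the same computation on made-up scores; the function name `masked_accuracy` and all the numbers are illustrative assumptions, not part of the original code.

```python
def masked_accuracy(logits, labels, mask):
    """Fraction of masked nodes whose highest-scoring class equals the label."""
    correct = 0
    total = 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue  # node is outside the evaluation set
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(pred == label)
        total += 1
    return correct / total


# Four hypothetical nodes, three classes; the last node is masked out.
logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9], [2.0, 1.0, 0.0]]
labels = [1, 2, 2, 0]
mask = [True, True, True, False]
acc = masked_accuracy(logits, labels, mask)
print(acc)  # 2 of the 3 masked nodes are classified correctly
```

Only the nodes selected by the mask contribute to the score, which is how the test split is kept separate from the training nodes even though the model sees the whole graph.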
We then train the network as follows:
import time
import numpy as np

g, features, labels, train_mask, test_mask = load_cora_data()
optimizer = th.optim.Adam(net.parameters(), lr=1e-2)
dur = []
for epoch in range(50):
    # The first three epochs are treated as warm-up and are not timed,
    # so the reported mean time is nan until epoch 3.
    if epoch >= 3:
        t0 = time.time()

    net.train()
    logits = net(g, features)
    logp = F.log_softmax(logits, 1)
    # Only the training nodes contribute to the loss.
    loss = F.nll_loss(logp[train_mask], labels[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    acc = evaluate(net, g, features, labels, test_mask)
    print("Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), acc, np.mean(dur)))
Out:
/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3257: RuntimeWarning: Mean of empty slice.
out=out, **kwargs)
/home/ubuntu/.pyenv/versions/miniconda3-latest/lib/python3.7/site-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars
ret = ret.dtype.type(ret / rcount)
Epoch 00000 | Loss 1.9289 | Test Acc 0.3300 | Time(s) nan
Epoch 00001 | Loss 1.7816 | Test Acc 0.3640 | Time(s) nan
Epoch 00002 | Loss 1.6358 | Test Acc 0.4120 | Time(s) nan
Epoch 00003 | Loss 1.4936 | Test Acc 0.4610 | Time(s) 0.0729
Epoch 00004 | Loss 1.3725 | Test Acc 0.5280 | Time(s) 0.0727
Epoch 00005 | Loss 1.2673 | Test Acc 0.5800 | Time(s) 0.0726
Epoch 00006 | Loss 1.1750 | Test Acc 0.6210 | Time(s) 0.0727
Epoch 00007 | Loss 1.0935 | Test Acc 0.6280 | Time(s) 0.0728
Epoch 00008 | Loss 1.0213 | Test Acc 0.6460 | Time(s) 0.0729
Epoch 00009 | Loss 0.9558 | Test Acc 0.6770 | Time(s) 0.0729
Epoch 00010 | Loss 0.8957 | Test Acc 0.6960 | Time(s) 0.0728
Epoch 00011 | Loss 0.8399 | Test Acc 0.7180 | Time(s) 0.0728
Epoch 00012 | Loss 0.7875 | Test Acc 0.7300 | Time(s) 0.0728
Epoch 00013 | Loss 0.7390 | Test Acc 0.7430 | Time(s) 0.0731
Epoch 00014 | Loss 0.6937 | Test Acc 0.7480 | Time(s) 0.0731
Epoch 00015 | Loss 0.6512 | Test Acc 0.7550 | Time(s) 0.0731
Epoch 00016 | Loss 0.6110 | Test Acc 0.7680 | Time(s) 0.0733
Epoch 00017 | Loss 0.5728 | Test Acc 0.7720 | Time(s) 0.0733
Epoch 00018 | Loss 0.5364 | Test Acc 0.7740 | Time(s) 0.0734
Epoch 00019 | Loss 0.5022 | Test Acc 0.7790 | Time(s) 0.0735
Epoch 00020 | Loss 0.4709 | Test Acc 0.7810 | Time(s) 0.0737
Epoch 00021 | Loss 0.4429 | Test Acc 0.7820 | Time(s) 0.0737
Epoch 00022 | Loss 0.4177 | Test Acc 0.7790 | Time(s) 0.0737
Epoch 00023 | Loss 0.3946 | Test Acc 0.7780 | Time(s) 0.0738
Epoch 00024 | Loss 0.3730 | Test Acc 0.7760 | Time(s) 0.0738
Epoch 00025 | Loss 0.3528 | Test Acc 0.7760 | Time(s) 0.0737
Epoch 00026 | Loss 0.3342 | Test Acc 0.7720 | Time(s) 0.0738
Epoch 00027 | Loss 0.3169 | Test Acc 0.7710 | Time(s) 0.0739
Epoch 00028 | Loss 0.3006 | Test Acc 0.7710 | Time(s) 0.0739
Epoch 00029 | Loss 0.2853 | Test Acc 0.7720 | Time(s) 0.0739
Epoch 00030 | Loss 0.2709 | Test Acc 0.7720 | Time(s) 0.0738
Epoch 00031 | Loss 0.2576 | Test Acc 0.7760 | Time(s) 0.0738
Epoch 00032 | Loss 0.2452 | Test Acc 0.7760 | Time(s) 0.0738
Epoch 00033 | Loss 0.2334 | Test Acc 0.7780 | Time(s) 0.0737
Epoch 00034 | Loss 0.2222 | Test Acc 0.7790 | Time(s) 0.0737
Epoch 00035 | Loss 0.2116 | Test Acc 0.7780 | Time(s) 0.0737
Epoch 00036 | Loss 0.2015 | Test Acc 0.7800 | Time(s) 0.0737
Epoch 00037 | Loss 0.1920 | Test Acc 0.7790 | Time(s) 0.0737
Epoch 00038 | Loss 0.1829 | Test Acc 0.7760 | Time(s) 0.0736
Epoch 00039 | Loss 0.1744 | Test Acc 0.7720 | Time(s) 0.0736
Epoch 00040 | Loss 0.1664 | Test Acc 0.7730 | Time(s) 0.0736
Epoch 00041 | Loss 0.1588 | Test Acc 0.7700 | Time(s) 0.0736
Epoch 00042 | Loss 0.1515 | Test Acc 0.7710 | Time(s) 0.0735
Epoch 00043 | Loss 0.1447 | Test Acc 0.7710 | Time(s) 0.0735
Epoch 00044 | Loss 0.1382 | Test Acc 0.7720 | Time(s) 0.0735
Epoch 00045 | Loss 0.1321 | Test Acc 0.7760 | Time(s) 0.0735
Epoch 00046 | Loss 0.1264 | Test Acc 0.7750 | Time(s) 0.0735
Epoch 00047 | Loss 0.1211 | Test Acc 0.7730 | Time(s) 0.0735
Epoch 00048 | Loss 0.1160 | Test Acc 0.7730 | Time(s) 0.0735
Epoch 00049 | Loss 0.1113 | Test Acc 0.7730 | Time(s) 0.0735
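The loss values in the log come from a log-softmax followed by a negative log-likelihood averaged over the training nodes. Here is a minimal sketch of that computation in plain Python, with made-up logits; the function names are illustrative and this is not PyTorch's implementation.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of one row of class scores."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def masked_nll(logits, labels, mask):
    """Mean negative log-likelihood over the nodes selected by mask."""
    losses = [-log_softmax(row)[y]
              for row, y, keep in zip(logits, labels, mask) if keep]
    return sum(losses) / len(losses)


# Three hypothetical nodes, two classes; only the first two are training nodes.
logits = [[0.0, 0.0], [2.0, 0.0], [0.0, 5.0]]
labels = [0, 0, 1]
train_mask = [True, True, False]
loss = masked_nll(logits, labels, train_mask)

# Loss of a classifier that is uniform over 7 classes, for comparison.
uniform_loss = math.log(7)
print(loss, uniform_loss)
```

A classifier that spreads its probability evenly over the 7 cora classes would score ln 7 ≈ 1.946, which is on the same scale as the first logged loss (1.9289), as expected for a freshly initialized network.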
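Mathematically, the GCN model follows this formula:
$$H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)}\right)$$
Here, $H^{(l)}$ denotes the $l^{th}$ layer of the network, $\sigma$ is the non-linearity, and $W$ is the weight matrix of this layer. $D$ and $A$, as usual, are the degree matrix and the adjacency matrix, respectively. The tilde marks a renormalization trick: we add a self-loop to every node of the graph and build the corresponding degree and adjacency matrices. The input $H^{(0)}$ has shape $N \times D$, where $N$ is the number of nodes and $D$ is the number of input features. We can chain multiple layers in this way to produce a node-level representation of shape $N \times F$, where $F$ is the dimension of the output feature vector.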
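For a small graph, the renormalized propagation matrix $\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}}$ can be built by hand. The sketch below assumes an undirected graph given as an edge list; the helper name `normalized_adjacency` and the three-node path graph are illustrative only. Note that the GCNLayer shown earlier sums neighbor features without this degree normalization.

```python
import math


def normalized_adjacency(num_nodes, edges):
    """Return D~^{-1/2} A~ D~^{-1/2} as a dense list of lists, with A~ = A + I."""
    a = [[0.0] * num_nodes for _ in range(num_nodes)]
    for u, v in edges:
        a[u][v] = 1.0
        a[v][u] = 1.0  # undirected graph: store both directions
    for i in range(num_nodes):
        a[i][i] = 1.0  # renormalization trick: add a self-loop to every node
    deg = [sum(row) for row in a]  # degrees counted on A~, self-loop included
    return [[a[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(num_nodes)]
            for i in range(num_nodes)]


# Hypothetical path graph 0 - 1 - 2
a_hat = normalized_adjacency(3, [(0, 1), (1, 2)])
for row in a_hat:
    print([round(x, 4) for x in row])
```

With self-loops, the degrees are 2, 3, and 2, so the entry between nodes 0 and 1 is $1/\sqrt{2 \cdot 3}$ and the diagonal entry of node 1 is $1/3$.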
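The equation can be implemented efficiently with sparse matrix multiplication kernels (as in Kipf's pygcn code). In fact, the DGL implementation above already uses this trick, because the built-in functions are executed through such kernels.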
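To see why a sparse kernel is equivalent to the dense formulation, the following sketch multiplies a tiny adjacency matrix by a feature matrix in two ways: once as a dense product, and once by looping over the non-zero entries only. The function names and the example graph are assumptions made for illustration; real sparse kernels are far more optimized than this loop.

```python
def dense_matmul(a, h):
    """Plain dense product A @ H; touches every entry of A."""
    return [[sum(a[i][k] * h[k][j] for k in range(len(h)))
             for j in range(len(h[0]))]
            for i in range(len(a))]


def coo_spmm(num_rows, entries, h):
    """Sparse product: entries is a list of (row, col, value) non-zeros of A."""
    out = [[0.0] * len(h[0]) for _ in range(num_rows)]
    for r, c, v in entries:
        for j in range(len(h[0])):
            out[r][j] += v * h[c][j]  # work is proportional to the edge count
    return out


# Hypothetical path graph 0 - 1 - 2 stored as non-zero entries (row, col, value)
entries = [(0, 1, 1.0), (1, 0, 1.0), (1, 2, 1.0), (2, 1, 1.0)]
dense = [[0.0] * 3 for _ in range(3)]
for r, c, v in entries:
    dense[r][c] = v

h = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
sparse_out = coo_spmm(3, entries, h)
dense_out = dense_matmul(dense, h)
print(sparse_out)  # [[3.0, 4.0], [6.0, 8.0], [3.0, 4.0]]
```

Both versions give the same result, but the sparse loop does work proportional to the number of edges rather than to $N^2$. That matters for graphs like cora, where most entries of the adjacency matrix are zero.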
Total running time: 0 minutes 13.200 seconds