Graph Convolutional Network
Analyzing GCN from a message-passing perspective:

- In a GCN, every node has its own representation $h_i$.
- Following the message-passing paradigm, each node receives messages (the representations) sent by its neighboring nodes.
- Each node aggregates the messages received from its neighbors into $\hat{h_i}$.
- The aggregated representation is then transformed, linearly or nonlinearly, by a function $f$.
- $\hat{h_i}$ is passed through the function: $f(W_u\hat{h_i}) = h^{new}_i$.
- The newly computed $h^{new}_i$ replaces the old representation: $h_i \leftarrow h^{new}_i$ (a minimal sketch of these steps follows this list).
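To make the steps above concrete, here is a minimal sketch of a single message-passing update in plain PyTorch. The toy adjacency matrix `A`, feature matrix `H`, and weight `W` are hypothetical values invented for illustration, not part of the DGL model built later:

import torch

# Hypothetical toy graph with 3 nodes; A[i][j] = 1 means node j sends a message to node i
A = torch.tensor([[0., 1., 1.],
                  [1., 0., 0.],
                  [1., 0., 0.]])
H = torch.randn(3, 4)            # h_i: current representations, one row per node
W = torch.randn(4, 4)            # W_u: learnable weight matrix

H_hat = A @ H                    # aggregate: each node sums its neighbors' messages
H_new = torch.relu(H_hat @ W)    # transform: h_i_new = f(W_u * h_i_hat), with f = ReLU
H = H_new                        # update: h_i <- h_i_new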
The mathematical formulation of GCN:

$$H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)}\right)$$
- $H^{(l)}$: the representations of all nodes at layer $l$
- $W^{(l)}$: the weight matrix of layer $l$
- $D$: the degree matrix
- $A$: the adjacency matrix
- $\tilde{D}$: renormalization trick: the degree matrix after adding a self-connection to every node in the graph
- $\tilde{A}$: renormalization trick: the adjacency matrix after adding the self-connections, $\tilde{A} = A + I$
- $H^{(0)}$: the input, i.e. the initial feature of each node
  - shape: $N \times F_{in}$
  - $N$: the number of nodes in the graph
  - $F_{in}$: the dimension of the input features
- $H^{(out)}$: the output, with shape $N \times F_{out}$
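The formula can be checked end to end on a toy example. The following sketch, using a hypothetical 3-node path graph, builds $\tilde{A}$ and $\tilde{D}^{-\frac{1}{2}}$ explicitly and computes one full layer:

import torch

# Hypothetical path graph 0 - 1 - 2
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
A_tilde = A + torch.eye(3)                   # A~ = A + I: add a self-connection to every node
deg = A_tilde.sum(dim=1)                     # node degrees after adding self-connections
D_inv_sqrt = torch.diag(deg.pow(-0.5))       # D~^{-1/2}
A_norm = D_inv_sqrt @ A_tilde @ D_inv_sqrt   # D~^{-1/2} A~ D~^{-1/2}

H0 = torch.randn(3, 5)                       # H^(0): N x F_in, with N=3, F_in=5
W0 = torch.randn(5, 2)                       # W^(0): F_in x F_out
H1 = torch.relu(A_norm @ H0 @ W0)            # H^(1) = sigma(D~^{-1/2} A~ D~^{-1/2} H^(0) W^(0))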
Build a GCN using DGL
import dgl
import torch as th
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F
from dgl import DGLGraph

# Message function: copy the source node's representation 'h' into the message 'm'
gcn_msg = fn.copy_src(src='h', out='m')
# Reduce function: sum all incoming messages 'm' into the aggregated representation 'h'
# Note: this aggregates with a plain (unnormalized) sum, a simplification of the
# D~^{-1/2} A~ D~^{-1/2} normalization in the formula above
gcn_reduce = fn.sum(msg='m', out='h')
class NodeApplyModule(nn.Module):
    # Applies the per-node update: h_i <- f(W * h_i_hat)
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h': h}
class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        g.ndata['h'] = feature
        # Message passing: every node sums its neighbors' representations
        g.update_all(gcn_msg, gcn_reduce)
        # Per-node update with the linear transform and activation
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Cora: 1433 input features, 16 hidden units, 7 classes
        self.gcn1 = GCN(1433, 16, F.relu)
        self.gcn2 = GCN(16, 7, F.relu)

    def forward(self, g, features):
        x = self.gcn1(g, features)
        x = self.gcn2(g, x)
        return x

GCnet = Net()
print(GCnet)
Net(
  (gcn1): GCN(
    (apply_mod): NodeApplyModule(
      (linear): Linear(in_features=1433, out_features=16, bias=True)
    )
  )
  (gcn2): GCN(
    (apply_mod): NodeApplyModule(
      (linear): Linear(in_features=16, out_features=7, bias=True)
    )
  )
)
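Before loading real data, a quick shape check can be run on a toy graph. This smoke test is a hypothetical addition (names `g_toy`, `x_toy` are invented), assuming the same pre-0.5 DGL API used by the rest of this code:

# Hypothetical smoke test: random 1433-dim features on a 5-node ring
g_toy = dgl.DGLGraph()
g_toy.add_nodes(5)
g_toy.add_edges([0, 1, 2, 3, 4], [1, 2, 3, 4, 0])
g_toy.add_edges(g_toy.nodes(), g_toy.nodes())   # self-connections, as in the trick above
x_toy = th.randn(5, 1433)
print(GCnet(g_toy, x_toy).shape)                # expected: torch.Size([5, 7])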
Load data (DGL built-in)
from dgl.data import citation_graph as citegrh

def load_cora_data():
    data = citegrh.load_cora()
    features = th.FloatTensor(data.features)
    labels = th.LongTensor(data.labels)
    mask = th.ByteTensor(data.train_mask)   # mask of the labeled training nodes
    g = data.graph
    # Renormalization trick: drop any existing self-loops, then add one per node
    g.remove_edges_from(g.selfloop_edges())
    g = DGLGraph(g)
    g.add_edges(g.nodes(), g.nodes())
    return g, features, labels, mask
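For reference, Cora has 2708 nodes with 1433-dimensional bag-of-words features, 7 classes, and 140 labeled training nodes, which the loader's return values can confirm:

g, features, labels, mask = load_cora_data()
print(features.shape)    # torch.Size([2708, 1433])
print(labels.shape)      # torch.Size([2708])
print(int(mask.sum()))   # 140 training nodes carry labels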
Train the model
import time
import warnings
import numpy as np

warnings.filterwarnings('ignore')

graph, features, labels, mask = load_cora_data()
optimizer = th.optim.Adam(GCnet.parameters(), lr=0.1)
dur = []
train_loss = []
for epoch in range(50):
    # Skip the first few epochs when timing, to exclude warm-up cost
    if epoch >= 3:
        t0 = time.time()

    logits = GCnet(graph, features)
    logp = F.log_softmax(logits, 1)
    # Compute the loss only on the labeled training nodes
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)
    train_loss.append(loss.item())
    print("Epoch %5d | Loss %.4f | Time(s) %.4f" % (epoch, loss.item(), np.mean(dur)))
Epoch 0 | Loss 0.9992 | Time(s) nan
Epoch 1 | Loss 1.0033 | Time(s) nan
Epoch 2 | Loss 2.8829 | Time(s) nan
Epoch 3 | Loss 1.7264 | Time(s) 0.2997
Epoch 4 | Loss 1.4124 | Time(s) 0.2961
Epoch 5 | Loss 0.8191 | Time(s) 0.2988
Epoch 6 | Loss 0.7352 | Time(s) 0.3071
Epoch 7 | Loss 0.6177 | Time(s) 0.3042
Epoch 8 | Loss 0.5425 | Time(s) 0.3030
Epoch 9 | Loss 0.4691 | Time(s) 0.3024
Epoch 10 | Loss 0.3825 | Time(s) 0.3019
Epoch 11 | Loss 0.3116 | Time(s) 0.3017
Epoch 12 | Loss 0.2253 | Time(s) 0.3036
Epoch 13 | Loss 0.1849 | Time(s) 0.3030
Epoch 14 | Loss 0.2047 | Time(s) 0.3027
Epoch 15 | Loss 0.1770 | Time(s) 0.3027
Epoch 16 | Loss 0.1390 | Time(s) 0.3023
Epoch 17 | Loss 0.0902 | Time(s) 0.3022
Epoch 18 | Loss 0.0822 | Time(s) 0.3023
Epoch 19 | Loss 0.0842 | Time(s) 0.3019
Epoch 20 | Loss 0.0796 | Time(s) 0.3027
Epoch 21 | Loss 0.0689 | Time(s) 0.3027
Epoch 22 | Loss 0.0667 | Time(s) 0.3025
Epoch 23 | Loss 0.0524 | Time(s) 0.3024
Epoch 24 | Loss 0.0486 | Time(s) 0.3025
Epoch 25 | Loss 0.0413 | Time(s) 0.3022
Epoch 26 | Loss 0.0382 | Time(s) 0.3021
Epoch 27 | Loss 0.0314 | Time(s) 0.3022
Epoch 28 | Loss 0.0282 | Time(s) 0.3019
Epoch 29 | Loss 0.0267 | Time(s) 0.3018
Epoch 30 | Loss 0.0254 | Time(s) 0.3018
Epoch 31 | Loss 0.0267 | Time(s) 0.3016
Epoch 32 | Loss 0.0248 | Time(s) 0.3016
Epoch 33 | Loss 0.0246 | Time(s) 0.3016
Epoch 34 | Loss 0.0240 | Time(s) 0.3014
Epoch 35 | Loss 0.0229 | Time(s) 0.3013
Epoch 36 | Loss 0.0225 | Time(s) 0.3014
Epoch 37 | Loss 0.0217 | Time(s) 0.3012
Epoch 38 | Loss 0.0210 | Time(s) 0.3012
Epoch 39 | Loss 0.0209 | Time(s) 0.3012
Epoch 40 | Loss 0.0204 | Time(s) 0.3011
Epoch 41 | Loss 0.0201 | Time(s) 0.3011
Epoch 42 | Loss 0.0200 | Time(s) 0.3015
Epoch 43 | Loss 0.0196 | Time(s) 0.3013
Epoch 44 | Loss 0.0194 | Time(s) 0.3013
Epoch 45 | Loss 0.0192 | Time(s) 0.3013
Epoch 46 | Loss 0.0189 | Time(s) 0.3014
Epoch 47 | Loss 0.0187 | Time(s) 0.3013
Epoch 48 | Loss 0.0185 | Time(s) 0.3013
Epoch 49 | Loss 0.0182 | Time(s) 0.3013
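The loop above tracks only training loss; to gauge generalization, one could also measure accuracy on held-out nodes. A minimal sketch, assuming the Cora loader's `test_mask` field is used the same way as `train_mask` above (the `evaluate` helper is hypothetical):

def evaluate(model, g, features, labels, mask):
    model.eval()
    with th.no_grad():
        logits = model(g, features)
        preds = logits[mask].argmax(dim=1)
        return (preds == labels[mask]).float().mean().item()

data = citegrh.load_cora()
test_mask = th.ByteTensor(data.test_mask)   # assumption: test_mask mirrors train_mask
print("Test accuracy: %.4f" % evaluate(GCnet, graph, features, labels, test_mask))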
import matplotlib.pyplot as plt

# Plot the training-loss curve recorded during the loop above
plt.figure(figsize=(15, 7))
plt.plot(train_loss)
plt.title('Train Loss')
plt.grid(True)
plt.show()
