PyG和DGL是GNN领域的两大框架,两大框架的底层都是基于消息传递机制,即PyG中的MessagePassing基类和DGL中的Message Passing Paradigm。
关于DGL的消息传递范式,前面已经有几篇文章进行过讲解:
本篇文章使用Citeseer网络。Citeseer网络是一个引文网络,节点为论文,一共3327篇论文。论文一共分为六类:Agents、AI(人工智能)、DB(数据库)、IR(信息检索)、ML(机器语言)和HCI。如果两篇论文间存在引用关系,那么它们之间就存在链接关系。
dataset = Planetoid(root='data', name='CiteSeer')
dataset = dataset[0]
dataset.edge_index, _ = add_self_loops(dataset.edge_index)
dataset = dataset.to(device)
num_in_feats, num_out_feats = dataset.num_node_features, torch.max(dataset.y).item() + 1
MessagePassing是PyG中定义的一个有关消息传递机制的基类,它通过自动处理消息传播来帮助创建此类消息传递图神经网络。用户只需定义消息函数message、聚合函数aggregate以及更新函数update,就能实现自定义GNN,这点和DGL类似。
消息传递的基本原理:
其中 x i ( k ) x_i^{(k)} xi(k)表示节点 i i i经过第 k k k层更新后的特征, e j , i e_{j,i} ej,i表示从节点 j j j到节点 i i i之间边的特征。
ϕ ( k ) \phi^{(k)} ϕ(k)表示第 k k k层的消息函数:例如可以将边特征和两个节点的特征求平均以得到新的特征。
□ \square □表示聚合函数:例如可以将节点 i i i的所有邻居节点经过消息函数处理后的特征进行求和,或者求和后再加上节点 i i i本身的特征 x i ( k − 1 ) x_i^{(k-1)} xi(k−1)。
γ ( k ) \gamma^(k) γ(k)表示第 k k k层的更新函数:例如可以将聚合后的特征经过简单的线性变换或者激活函数。
GCN的具体数学原理为:
对应到上面所讲的MessagePassing:
message:将节点 i i i的所有邻居节点进行一个简单的线性变换,然后将这些特征乘上一个权重 d e g ( i ) − 1 2 ⋅ d e g ( j ) − 1 2 deg(i)^{-\frac{1}{2}} \cdot deg(j)^{-\frac{1}{2}} deg(i)−21⋅deg(j)−21,其中 d e g ( i ) deg(i) deg(i)表示节点 i i i添加自环后的度。
因此,具体的代码实现为:
def message(self, x, edge_index):
x = self.linear(x)
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
x_j = x[col] # target nodes
x_j = norm.view(-1, 1) * x_j # 12431条边上target nodes的feature * norm
return x_j
首先,我们将节点特征 x x x经过一个线性变换:
x = self.linear(x)
然后,得到所有节点的度的-0.5次方并计算乘积:
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
其中,得到所有节点度的操作为:
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
其中col就是图中所有边中的目标节点,方法逻辑:统计col中从0到x.size(0)-1
(节点数)中每个数出现的次数,该次数就是节点的度。注意,传入row时计算的实际上是出度,而传入col时计算的是入度,对于本文中使用的无向图来讲,二者计算结果一致。
得到所有节点的度后,计算每条边上源节点和目标节点度的-0.5次方的乘积以得到权重:
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
最后,我们需要将邻居节点(目标节点)的特征乘上该权重并返回:
x_j = x[col] # target nodes
x_j = norm.view(-1, 1) * x_j # 12431条边上target nodes的feature * norm
return x_j
aggregate:将节点 i i i所有邻居节点 x j x_j xj的经过消息函数处理后的特征求和。
def aggregate(self, x_j, edge_index):
# x_j为target nodes的归一化特征
row, col = edge_index
# row(12431), x_j(12431, out_channels)
out = scatter(x_j, row, dim=0, reduce='sum')
return out
这里使用了torch_scatter中的scatter方法来对所有邻居节点的特征进行聚合。具体来讲,就是根据row中相同索引对应的 x j x_j xj中的元素进行求和处理,然后按照索引进行排序,其中 x j x_j xj为前面消息函数求得的所有边中目标节点的加权特征。
比如一共12431条边,那么row
就是12431条边中源节点的索引值,假设索引0一共出现在了5个位置(节点0的出度为5),那么最终得到的out的第一个元素就是将 x j x_j xj中这5个位置的特征求和,也就是节点0的5个邻居节点的特征求和。
这样,经过scatter方法处理后,我们就得到了所有节点的更新后的特征值。
关于torch_scatter.scatter()的具体使用方法可以参考:torch_scatter.scatter()的使用方法详解。
观察PyG中对GCN的定义:
因此,我们可以将update简单理解为加上一个bias,即:
def update(self, out):
return out + self.bias
在PyG中,MessagePassing通过调用propagate方法来实现图上的一次卷积操作,即前面提到的message、aggregate以及update操作:
def propagate(self, x, edge_index):
out = self.message(x, edge_index)
out = self.aggregate(out, edge_index)
out = self.update(out)
return out
因此,一个完整的GCNConv搭建如下:
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add')
self.linear = nn.Linear(in_channels, out_channels, bias=False)
self.bias = Parameter(torch.Tensor(out_channels))
def message(self, x, edge_index):
x = self.linear(x)
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
x_j = x[col] # target nodes
x_j = norm.view(-1, 1) * x_j # 12431条边上target nodes的feature * norm
return x_j
def aggregate(self, x_j, edge_index):
# x_j为target nodes的归一化特征
row, col = edge_index
# row(12431), x_j(12431, out_channels)
out = scatter(x_j, row, dim=0, reduce='sum')
return out
def update(self, out):
return out + self.bias
def propagate(self, x, edge_index):
out = self.message(x, edge_index)
out = self.aggregate(out, edge_index)
out = self.update(out)
return out
def forward(self, x, edge_index):
return self.propagate(x, edge_index)
一个简单的两层GCN搭建如下:
class GCN(torch.nn.Module):
def __init__(self, num_node_features, num_classes):
super(GCN, self).__init__()
self.conv1 = GCNConv(num_node_features, 32)
self.conv2 = GCNConv(32, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = F.softmax(x, dim=1)
return x
训练:
def train(model):
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
loss_function = torch.nn.CrossEntropyLoss().to(device)
model.train()
min_epochs = 10
best_model = None
min_val_loss = 5
for epoch in range(200):
out = model(dataset)
optimizer.zero_grad()
loss = loss_function(out[dataset.train_mask], dataset.y[dataset.train_mask])
loss.backward()
optimizer.step()
# validation
val_loss = get_val_loss(model)
if epoch + 1 >= min_epochs and val_loss < min_val_loss:
min_val_loss = val_loss
best_model = copy.deepcopy(model)
print('Epoch: {:3d} train_Loss: {:.5f} val_loss: {:.5f}'.format(epoch, loss.item(), val_loss))
model.train()
return best_model
测试:
def test(model):
model.eval()
_, pred = model(dataset).max(dim=1)
correct = int(pred[dataset.test_mask].eq(dataset.y[dataset.test_mask]).sum().item())
acc = correct / int(dataset.test_mask.sum())
print('GCN Accuracy: {:.4f}'.format(acc))
实验结果:69.8%的准确率。
代码地址:GNNs-for-Node-Classification。原创不易,下载时请给个follow和star!感谢!!