GNN实战——KarateClub数据集

GNN:graph neural network 图神经网络,是⼀种连接模型,通过⽹络中节点之间的信息传递(message passing)的⽅式来获取图中的依存关系(dependence of graph),GNN通过从节点任意深度的邻居来更新该节点状态,这个状态能够表示状态信息。由于 GNN 在图节点之间强大的建模功能,使得与图分析相关的研究领域取得了突破。图神经网络(GNN)是一类基于深度学习的处理图域信息的方法。由于其较好的性能和可解释性,现已被广泛应用到各个领域。涵盖了推荐系统、组合优化、计算机视觉、物理 / 化学以及药物发现等领域。

一、数据集介绍

数据集中只有一张图。
GNN实战——KarateClub数据集_第1张图片
该图描述了一个空手道俱乐部会员的社交关系,以34名会员作为节点,如果两位会员在俱乐部之外仍保持社交关系,则在节点间增加一条边。
每个节点具有一个34维的特征向量,一共有78条边。在收集数据的过程中,管理人员 John A 和 教练 Mr. Hi(化名)之间产生了冲突,会员们选择了站队,一半会员跟随 Mr. Hi 成立了新俱乐部,剩下一半会员找了新教练或退出了俱乐部。通过收集到的图数据,Zachary 进行了分类,除1名会员外都分类正确。将原图进行抽象可得到下图:
GNN实战——KarateClub数据集_第2张图片

二、GNN实战

1. 导入所需的包

%matplotlib inline
import torch
import networkx as nx
import matplotlib.pyplot as plt
# KarateClub是torch_geometric内置的数据集
from torch_geometric.datasets import KarateClub

注:torch_geometric库的安装不能直接pip install,具体的安装方法可以参考之前的blog:https://blog.csdn.net/m0_51339444/article/details/128611141

2. 定义可视化函数

def visualize_graph(G, color):
    plt.figure(figsize=(5,5))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap='Set2')
    plt.show()

3. 导入并查看KarateClub数据集

dataset = KarateClub()
print(f'Dataset: {dataset}:')
print(f'Number of the graphs: {len(dataset)}')
print(f'Number of the features: {dataset.num_features}')
print(f'Number of the classes: {dataset.num_classes}')

GNN实战——KarateClub数据集_第3张图片

data = dataset[0]
print(data)

在这里插入图片描述

# edge_index是邻接矩阵,表示每两个点之间的关联
edge_index = data.edge_index
# 打印出每个点分别和谁有关系
print(edge_index.t())

这里对上一个运行结果解释一下,这是整个数据集的全部生态环境了,x是特征,就是一个一个的点,第一个34表示一共有34个点,即34个样本,第二个34表示每个样本是34维的向量(即34个特征);edge_index是邻接矩阵,表示每两个点之间的关联,第一个元素一定是2,表示两个点之间的边,156表示一共有156个关系,即156条边;train_mask记录了34个数据中有标签与否,有标签是True,没有标签是False。

4. 使用networkx进行可视化展示

# 将处理好(对应的标准格式)的data传入to_networkx,再传入visualize_graph(最上面自己定义的)绘图
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

GNN实战——KarateClub数据集_第4张图片

5. 搭建网络

这里会使用到torch_geometric的方法(封装好的函数),有疑问的地方可以去官网查询API,这里拍个链接:https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4)  # 两个参数分别为输入特征和输出特征
        self.conv2 = GCNConv(4,4)
        self.conv3 = GCNConv(4,2)
        self.classifier = Linear(2, dataset.num_classes)   
    # x是特征,没经过一层后数据都是不断变化的,即x 变成h,h不断变成新的h,而edge_index邻接矩阵是一直不变的,谁和谁之间有联系是不变的    
    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index) # 输入特征和邻接矩阵
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()        
        # 分类层
        out = self.classifier(h)        
        # out是输出,h是中间结果(conv3的输出)(一个2维的向量(方便绘图打印))
        return out, h
            
model = GCN()
print(model)

由于数据集比较小,因此搭建小网络即可,网络参数如下:
GNN实战——KarateClub数据集_第5张图片

6. 进行embedding操作并可视化

def visualize_embedding(h, color, epoch=None, loss=None):
    plt.figure(figsize=(5,5))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap='Set2')
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    plt.show()
model = GCN()
_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')
visualize_embedding(h, color=data.y)

GNN实战——KarateClub数据集_第6张图片

7. 训练模型

import time 

model = GCN()
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train(data):
    optimizer.zero_grad()
    out, h = model(data.x, data.edge_index) 
    # 这里体现了半监督的思想,只拿有标签的计算损失,没有标签的不参与计算
    loss = loss_function(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss, h

for epoch in range(401):
    loss, h = train(data)
    if epoch % 10 == 0:
        visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
        time.sleep(0.3)

在这里插入图片描述
在这里插入图片描述
可以看到,随着epoch的增大,损失函数逐渐收敛,可视化结果逐渐将三种颜色分成了三个类别(类似聚类的结果)。

你可能感兴趣的:(GNN,GNN,深度学习,python,pytorch)