Understanding a GCN reimplementation with PyTorch and PyTorch Geometric (PyG)

Contents

I. Understanding MessagePassing

II. Processing the dataset (using Cora as an example)

III. Analysis of the GCN reimplementation


I recently studied a reimplementation of "Semi-Supervised Classification with Graph Convolutional Networks" in PyTorch and PyTorch Geometric (PyG). Below are my notes on PyTorch Geometric and on the code.

I. Understanding MessagePassing


1. Creating Message Passing Networks

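Generalizing the convolution operator to irregular domains is typically expressed as a neighborhood aggregation or message passing scheme:

$$\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i} \right) \right) \qquad (1)$$

$\mathbf{x}_i^{(k)}$: the features of node i in layer (k)

$\mathbf{e}_{j,i}$: (optional) the features of the edge from node j to node i

where $\square$ denotes a differentiable, permutation-invariant function (e.g., sum, mean, or max), and $\gamma$ and $\phi$ denote differentiable functions such as MLPs (multilayer perceptrons).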
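To make equation (1) concrete, here is a minimal sketch in plain Python on a made-up three-node graph. The functions `phi` and `gamma`, the `AGGREGATORS` table, and the feature values are illustrative assumptions, not part of PyG; they only show how the choice of aggregation changes the result.

```python
# Toy undirected graph with 3 nodes: edges 0-1 and 0-2, stored in both
# directions as (source j, target i) pairs. Every node has a neighbor.
edges = [(1, 0), (0, 1), (2, 0), (0, 2)]
x = [1.0, 2.0, 6.0]  # one scalar feature per node (made-up values)

def phi(x_i, x_j):
    """Message function: here it simply forwards the neighbor's feature."""
    return x_j

def gamma(x_i, aggregated):
    """Update function: here it adds the aggregated message to x_i."""
    return x_i + aggregated

# Permutation-invariant aggregation functions
AGGREGATORS = {
    "add": sum,
    "mean": lambda msgs: sum(msgs) / len(msgs),
    "max": max,
}

def message_passing_step(x, edges, aggr="add"):
    new_x = []
    for i in range(len(x)):
        # Collect the messages sent to node i by its neighbors j
        msgs = [phi(x[i], x[j]) for j, target in edges if target == i]
        new_x.append(gamma(x[i], AGGREGATORS[aggr](msgs)))
    return new_x

out_add = message_passing_step(x, edges, "add")    # [9.0, 3.0, 7.0]
out_mean = message_passing_step(x, edges, "mean")  # [5.0, 3.0, 7.0]
out_max = message_passing_step(x, edges, "max")    # [7.0, 3.0, 7.0]
```

Only node 0 has two neighbors, so it is the only node whose result depends on the aggregation scheme.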

2. The “MessagePassing” Base Class

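PyTorch Geometric provides the torch_geometric.nn.MessagePassing base class.

To define a convolution layer, inherit from this base class. You then only need to define the function ϕ, i.e. message(), the function γ, i.e. update(), and the aggregation scheme to use, i.e. aggr='add', aggr='mean' or aggr='max'.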

 

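3. Implementing the GCN Layer

The GCN layer is mathematically defined as

$$\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right)$$

where neighboring node features are first transformed by a weight matrix Θ, then normalized by their degree, and finally summed up. This formula can be divided into the following five steps: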

  1. Add self-loops to the adjacency matrix.
  2. Linearly transform node feature matrix.
  3. Normalize node features in ϕ.
  4. Sum up neighboring node features (“add” aggregation).
  5. Return new node embeddings in γ.

Steps 1-2 are typically computed before message passing takes place. Steps 3-5 can be easily processed using the (torch_geometric.nn.MessagePassing) base class. 
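In matrix form the five steps correspond to computing D^{-1/2} (A + I) D^{-1/2} X Θ, where D is the degree matrix of A + I. The following is a small dense sketch of that computation in plain PyTorch; the three-node path graph, the features, and the random weights `Theta` are made up for illustration, and a real layer would work on the sparse edge list instead.

```python
import torch

torch.manual_seed(0)

# Toy path graph 0 - 1 - 2 (made-up example, not Cora)
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
X = torch.tensor([[1., 0.],
                  [0., 1.],
                  [1., 1.]])               # 3 nodes, 2 input features
Theta = torch.randn(2, 4)                  # hypothetical weights: 2 -> 4 channels

A_hat = A + torch.eye(3)                   # Step 1: add self-loops
XW = X @ Theta                             # Step 2: linear transform
deg = A_hat.sum(dim=1)                     # degrees including self-loops: 2, 3, 2
D_inv_sqrt = torch.diag(deg.pow(-0.5))
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt   # Step 3: symmetric normalization
out = A_norm @ XW                          # Steps 4-5: sum over neighbors
```

Each entry of `A_norm` is 1 / (sqrt(deg(i)) * sqrt(deg(j))) for connected nodes, for example 1 / sqrt(2 * 3) between nodes 0 and 1.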

II. Processing the dataset (using Cora as an example)

1. Loading the dataset

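import os.path as osp
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

dataset_name = 'Cora'
# Store the data under ./data/Cora. The quoted '__file__' is resolved against
# the working directory, so this also runs in a notebook.
path = osp.join(osp.dirname(osp.realpath('__file__')), 'data', dataset_name)

# Load the dataset; pass the transform by keyword so it is not taken for another argument
dataset = Planetoid(path, dataset_name, transform=T.NormalizeFeatures())
data = dataset[0]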

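2. Checking that node features are normalized

dataset[0] contains all the information in the dataset, as follows:

(1) data.x (node features, a tensor of size 2708*1433)

import torch

# Check feature normalization: data.x is already normalized
x = data.x
# Indices of the nonzero features of nodes 0 and 1
print(torch.nonzero(x[0]))
print(torch.nonzero(x[1]))
# Number of nonzero features of nodes 0 and 1
print(torch.nonzero(x[0]).shape)
print(torch.nonzero(x[1]).shape)
# Value of feature 19, which is nonzero for both nodes
print(x[0,19])
print(x[1,19])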

The output is:

tensor([[  19],
        [  81],
        [ 146],
        [ 315],
        [ 774],
        [ 877],
        [1194],
        [1247],
        [1274]])
tensor([[  19],
        [  88],
        [ 149],
        [ 212],
        [ 233],
        [ 332],
        [ 336],
        [ 359],
        [ 472],
        [ 507],
        [ 548],
        [ 687],
        [ 763],
        [ 808],
        [ 889],
        [1058],
        [1177],
        [1254],
        [1257],
        [1262],
        [1332],
        [1339],
        [1349]])
torch.Size([9, 1])
torch.Size([23, 1])
tensor(0.1111)
tensor(0.0435)
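The two printed values are 1/9 ≈ 0.1111 and 1/23 ≈ 0.0435: node 0 has 9 nonzero features and node 1 has 23, which is what you would expect if each row of binary features is divided by its row sum. The sketch below reproduces this on made-up data; `row_normalize` is a hypothetical helper written for illustration and is not claimed to be the exact code behind T.NormalizeFeatures.

```python
import torch

def row_normalize(x):
    """Divide every row by its sum; rows whose sum is below 1 are divided by 1."""
    row_sum = x.sum(dim=1, keepdim=True).clamp(min=1.0)
    return x / row_sum

# Hypothetical binary features: row 0 has 9 ones, row 1 has 23 ones
x = torch.zeros(2, 30)
x[0, :9] = 1.0
x[1, :23] = 1.0

x_norm = row_normalize(x)
first_entries = x_norm[:, 0]   # 1/9 and 1/23
row_sums = x_norm.sum(dim=1)   # each row now sums to 1
```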

(2) Graph structure and other attributes

print("data:", data)
# Number of features: 1433
print("dataset.num_features:", dataset.num_features)
# Number of classes: 7
print("dataset.num_classes:", dataset.num_classes)

# Shape of the tensor x ([2708, 1433]): 2708 nodes, each with 1433 features
print("data.x.shape:", data.x.shape)
# edge_index holds the edges ([2, 10556]) as pairs of connected nodes
print("data.edge_index.shape:", data.edge_index.shape)
x, edge_index = data.x, data.edge_index
print("edge_index.shape:", edge_index.shape)

# Edge feature matrix with shape [num_edges, num_edge_features]
print("data.edge_attr:", data.edge_attr)

# Class label of each of the 2708 nodes
print("data.y:", data.y)
print("data.y[data.train_mask]", data.y[data.train_mask])

# Whether the graph is undirected
print("data.is_undirected:", data.is_undirected())

# train_mask denotes against which nodes to train (140 nodes)
# val_mask denotes which nodes to use for validation, e.g., to perform early stopping (500 nodes)
# test_mask denotes against which nodes to test (1000 nodes)
print(data.train_mask.sum().item())
print(data.val_mask.sum().item())
print(data.test_mask.sum().item())

print("data.train_mask:", data.train_mask)

The output is:

data: Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
dataset.num_features: 1433
dataset.num_classes: 7
data.x.shape: torch.Size([2708, 1433])
data.edge_index.shape: torch.Size([2, 10556])
edge_index.shape: torch.Size([2, 10556])
data.edge_attr: None
data.y: tensor([3, 4, 4,  ..., 3, 3, 3])
data.y[data.train_mask] tensor([3, 4, 4, 0, 3, 2, 0, 3, 3, 2, 0, 0, 4, 3, 3, 3, 2, 3, 1, 3, 5, 3, 4, 6,
        3, 3, 6, 3, 2, 4, 3, 6, 0, 4, 2, 0, 1, 5, 4, 4, 3, 6, 6, 4, 3, 3, 2, 5,
        3, 4, 5, 3, 0, 2, 1, 4, 6, 3, 2, 2, 0, 0, 0, 4, 2, 0, 4, 5, 2, 6, 5, 2,
        2, 2, 0, 4, 5, 6, 4, 0, 0, 0, 4, 2, 4, 1, 4, 6, 0, 4, 2, 4, 6, 6, 0, 0,
        6, 5, 0, 6, 0, 2, 1, 1, 1, 2, 6, 5, 6, 1, 2, 2, 1, 5, 5, 5, 6, 5, 6, 5,
        5, 1, 6, 6, 1, 5, 1, 6, 5, 5, 5, 1, 5, 1, 1, 1, 1, 1, 1, 1])
data.is_undirected: True
140
500
1000
data.train_mask: tensor([ True,  True,  True,  ..., False, False, False])
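The shape [2, 10556] together with is_undirected being True suggests that every undirected edge is stored twice, once per direction; if there are no self-loops, that would be 10556 / 2 = 5278 undirected edges. Here is a small sketch of this storage format on a made-up graph; the tensor below is illustrative and not taken from Cora.

```python
import torch

# Toy undirected graph: edges 0-1 and 1-2, each stored in both directions
edge_index = torch.tensor([[0, 1, 1, 2],    # source nodes
                           [1, 0, 2, 1]])   # target nodes

num_directed = edge_index.size(1)           # 4 columns
num_undirected = num_directed // 2          # 2 undirected edges

# The graph is undirected if every (src, dst) pair also appears as (dst, src)
pairs = set(map(tuple, edge_index.t().tolist()))
reversed_pairs = {(dst, src) for src, dst in pairs}
is_undirected = pairs == reversed_pairs
```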

III. Analysis of the GCN reimplementation

1. Defining the convolution layer

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels] (first layer on Cora: [2708, 1433])
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        # On Cora, x.size(0) = 2708, so E grows from 10556 to 10556 + 2708 = 13264.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        # First layer on Cora: [2708, 1433] -> [2708, 16]
        x = self.lin(x)

        # Step 3-5: Start propagating messages.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j, edge_index, size):
        # x_j has shape [E, out_channels]

        # Step 3: Normalize node features.
        row, col = edge_index

        # Compute the degree of every node (self-loops included)
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)

        # One normalization coefficient per edge, shape [E]
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Broadcast [E, 1] * [E, out_channels] -> [E, out_channels]
        # (on Cora, E = 10556 + 2708 = 13264)
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # aggr_out has shape [N=2708, out_channels]

        # Step 5: Return new node embeddings.
        return aggr_out
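To see what the message and aggregation steps compute, the following sketch replays steps 1, 3 and 4 in plain PyTorch, without PyG, on a made-up three-node graph, and cross-checks the result against the dense formula D^{-1/2} (A + I) D^{-1/2} X. All tensors and names are illustrative assumptions; `x[row]` stands in for the `x_j` that the base class would gather.

```python
import torch

torch.manual_seed(0)

num_nodes = 3
# Toy path graph 0 - 1 - 2, both directions (made-up example)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
x = torch.randn(num_nodes, 4)            # stands for the already transformed features

# Step 1: append one self-loop (i, i) per node
loops = torch.arange(num_nodes).repeat(2, 1)
edge_index = torch.cat([edge_index, loops], dim=1)
row, col = edge_index

# Step 3: per-edge coefficient 1 / (sqrt(deg(row)) * sqrt(deg(col)))
deg = torch.bincount(row, minlength=num_nodes).to(x.dtype)
norm = deg.pow(-0.5)[row] * deg.pow(-0.5)[col]
messages = norm.view(-1, 1) * x[row]     # shape [E, 4]

# Step 4: "add" aggregation, summing the messages at their target node
out = torch.zeros_like(x).index_add_(0, col, messages)

# Cross-check against the dense formula
A_hat = torch.zeros(num_nodes, num_nodes)
A_hat[row, col] = 1.0
D_inv_sqrt = torch.diag(deg.pow(-0.5))
dense = D_inv_sqrt @ A_hat @ D_inv_sqrt @ x
```

The edge-wise result `out` and the dense result `dense` agree up to floating-point error, which is why the layer can work directly on the edge list.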

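2. Defining the network model

import torch.nn.functional as F

class Net(torch.nn.Module):
    # torch.nn.Module is the base class of all neural network modules
    def __init__(self):
        # Run the initializer of the parent class nn.Module first
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self):
        # The model reads the global data object instead of taking arguments
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is only active in training mode
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        # Log-probabilities over the classes, one row per node
        return F.log_softmax(x, dim=1)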
3. Setting the device and defining the optimizer

############################## Set the device and define the optimizer #############################
device = torch.device('cpu')
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

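4. Defining the training function

############################## Define the training function #############################
def train():
    model.train()
    # Reset the gradients to zero before backpropagation
    optimizer.zero_grad()
    # Compute the loss on the training nodes and backpropagate it
    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
    # Update the parameters
    optimizer.step()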
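The training function pairs F.nll_loss with the F.log_softmax output of the model and restricts both to the training nodes through a boolean mask. As a sanity check of that pairing, this sketch on made-up scores shows that the combination gives the same value as F.cross_entropy applied to the raw scores; the tensors `scores`, `y` and `train_mask` are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical raw scores for 5 nodes and 3 classes, with made-up labels
scores = torch.randn(5, 3)
y = torch.tensor([0, 2, 1, 1, 0])
train_mask = torch.tensor([True, True, False, False, True])

log_probs = F.log_softmax(scores, dim=1)

# Loss over the training nodes only, as in train()
loss_nll = F.nll_loss(log_probs[train_mask], y[train_mask])

# Same quantity computed directly from the raw scores
loss_ce = F.cross_entropy(scores[train_mask], y[train_mask])
```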

5. Defining the test function

############################## Define the test function #############################
def test():
    model.eval()
    logits, accs = model(), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs
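The accuracy computation inside test() can be traced on a small made-up example. The values below are illustrative only; the sketch shows that `.max(1)[1]` returns the index of the largest entry per row (the same as `argmax`) and that the mask limits both the predictions and the denominator.

```python
import torch

# Made-up log-probabilities for 4 nodes and 3 classes
logits = torch.log(torch.tensor([[0.7, 0.2, 0.1],
                                 [0.1, 0.8, 0.1],
                                 [0.3, 0.3, 0.4],
                                 [0.6, 0.3, 0.1]]))
y = torch.tensor([0, 1, 0, 0])
mask = torch.tensor([True, True, True, False])   # evaluate the first 3 nodes only

pred = logits[mask].max(1)[1]                    # predicted class per masked node
same_as_argmax = torch.equal(pred, logits[mask].argmax(dim=1))
acc = pred.eq(y[mask]).sum().item() / mask.sum().item()   # 2 correct out of 3
```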

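6. Training and testing the model

############################## Train and test the model #############################
# Print the names of the parameters that take part in training (once, before the loop)
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

best_val_acc = test_acc = 0
for epoch in range(1, 201):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    # Keep the test accuracy from the epoch with the best validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc

    # "Val" and "Test" report the best validation accuracy so far and its test accuracy
    log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    print(log.format(epoch, train_acc, best_val_acc, test_acc))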
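The selection rule in the loop reports the test accuracy of the epoch with the best validation accuracy, not the best test accuracy seen. A tiny plain-Python sketch with made-up accuracy values (the `history` list is hypothetical) makes the difference visible:

```python
# Made-up (val_acc, test_acc) pairs for five epochs
history = [(0.60, 0.58), (0.72, 0.70), (0.70, 0.75), (0.78, 0.76), (0.78, 0.80)]

best_val_acc = test_acc = 0
for val_acc, tmp_test_acc in history:
    # The strict ">" keeps the first epoch that reaches the best validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc

max_test_acc = max(t for _, t in history)
```

Here `test_acc` ends at 0.76 even though a later epoch reaches 0.80 on the test set, because the test set plays no part in choosing the epoch.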

I was short on time when writing this; more updates will follow.
