Running GCN with torch_geometric

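torch_geometric (PyTorch Geometric) is a graph-learning library built on top of PyTorch. With it, implementing a GCN is simple, because the graph convolution layer comes ready-made as GCNConv.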
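For reference, GCNConv with its default settings applies the standard GCN propagation rule X' = D^(-1/2) (A + I) D^(-1/2) X W. Here A + I is the adjacency matrix with a self-loop added to every node, and D is its degree matrix. The sketch below is a dense NumPy version, meant only to show what the packaged layer computes. The function name gcn_propagate is hypothetical and not part of torch_geometric. The real layer also adds a bias by default and uses sparse message passing, so this dense form only makes sense for tiny graphs.

```python
import numpy as np


def gcn_propagate(x, edge_index, weight):
    """Dense sketch of one GCN layer, bias omitted: D^-1/2 (A + I) D^-1/2 X W.

    x:          (N, F_in) node feature matrix
    edge_index: (2, E) integer array of [source, target] pairs; an undirected
                graph lists every edge in both directions
    weight:     (F_in, F_out) weight matrix
    """
    n = x.shape[0]
    a_hat = np.eye(n)                            # A + I: a self-loop on every node
    a_hat[edge_index[1], edge_index[0]] = 1.0    # row = target node, column = source node
    deg = a_hat.sum(axis=1)                      # degree of each node, self-loop included
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # symmetric normalisation: entry (i, j) becomes a_ij / sqrt(d_i * d_j)
    norm_adj = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]
    return norm_adj @ x @ weight


if __name__ == "__main__":
    # path graph 0 - 1 - 2, each undirected edge stored in both directions
    edge_index = np.array([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
    x = np.eye(3)       # one-hot features, so the output equals the normalised adjacency
    weight = np.eye(3)
    print(gcn_propagate(x, edge_index, weight).round(3))
```

In this toy path graph, node 0 has degree 2 and node 1 has degree 3 once the self-loops are counted. Node 0 therefore keeps 1/2 of its own feature and receives 1/sqrt(6) ≈ 0.408 from node 1.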

First, install torch_geometric:

$ pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html
$ pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html
$ pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html
$ pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html
$ pip install torch-geometric

Replace ${CUDA} with cu110, cu102 or cu101, whichever matches the CUDA version your PyTorch build uses.
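To work out the right value, you can read the PyTorch version and the CUDA version it was built with (torch.__version__ and torch.version.cuda) and assemble the URL used in the commands above. The sketch below is based on that assumption. The helper name pyg_wheel_index is hypothetical, and the "cpu" tag for CPU-only builds (where torch.version.cuda is None) is my assumption. Wheel pages may also not exist for every patch release, so open the resulting URL to confirm it is there before running pip.

```python
def pyg_wheel_index(torch_version, cuda_version):
    """Build the find-links URL for the pip commands (hypothetical helper).

    torch_version: e.g. torch.__version__, such as "1.7.0" or "1.7.0+cu110"
    cuda_version:  e.g. torch.version.cuda, such as "11.0"; None for CPU-only builds
    """
    base = torch_version.split("+")[0]   # drop the local tag, e.g. "+cu110"
    # "11.0" -> "cu110"; the "cpu" fallback is an assumption
    tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://pytorch-geometric.com/whl/torch-{base}+{tag}.html"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed; install it before torch_geometric.")
    else:
        print(pyg_wheel_index(torch.__version__, torch.version.cuda))
```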

Since I am using CUDA 11.0, the commands become:

$ pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu110.html
$ pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu110.html
$ pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+cu110.html
$ pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.7.0+cu110.html
$ pip install torch-geometric
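After installation, a quick way to confirm that all five packages can be found uses only the standard library. Note that the pip names above use hyphens, while the import names use underscores. The sketch below assumes nothing beyond importlib, and check_installed is a hypothetical helper name.

```python
import importlib.util

# import names of the packages installed above
PYG_MODULES = ("torch_scatter", "torch_sparse", "torch_cluster",
               "torch_spline_conv", "torch_geometric")


def check_installed(modules=PYG_MODULES):
    """Return {module_name: True/False}, depending on whether Python can locate it."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, ok in check_installed().items():
        print(f"{name:20s} {'OK' if ok else 'MISSING'}")
```

find_spec only checks that a package can be located without importing it. A wheel built for a different PyTorch or CUDA version can still fail when it is actually imported, so running `import torch_geometric` remains the final test.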

Once everything is installed, the example can be run directly. The dataset used here is Cora.

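I have uploaded the dataset to Baidu Netdisk, because the default download location is GitHub and the connection fails easily.

Link: https://pan.baidu.com/s/1mmk2oESN25WdyRKTqgo3sg
Extraction code: 7pd6
Extract it into the directory containing main.py: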

├── [  52]  data
│   ├── [  34]  Cora
│   │   ├── [  66]  processed
│   │   │   ├── [ 15M]  data.pt
│   │   │   ├── [ 431]  pre_filter.pt
│   │   │   └── [ 431]  pre_transform.pt
│   │   └── [ 171]  raw
│   │       ├── [251K]  ind.cora.allx
│   │       ├── [ 47K]  ind.cora.ally
│   │       ├── [ 58K]  ind.cora.graph
│   │       ├── [4.9K]  ind.cora.test.index
│   │       ├── [145K]  ind.cora.tx
│   │       ├── [ 27K]  ind.cora.ty
│   │       ├── [ 22K]  ind.cora.x
│   │       └── [4.0K]  ind.cora.y
│   ├── [  34]  ENZYMES
│   │   ├── [  66]  processed
│   │   │   ├── [2.7M]  data.pt
│   │   │   ├── [ 431]  pre_filter.pt
│   │   │   └── [ 431]  pre_transform.pt
│   │   └── [ 178]  raw
│   │       ├── [864K]  ENZYMES_A.txt
│   │       ├── [ 73K]  ENZYMES_graph_indicator.txt
│   │       ├── [1.2K]  ENZYMES_graph_labels.txt
│   │       ├── [3.7M]  ENZYMES_node_attributes.txt
│   │       ├── [ 38K]  ENZYMES_node_labels.txt
│   │       └── [2.5K]  README.txt
│   └── [ 166]  ENZYMES.zip
└── [1.5K]  main.py
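Planetoid should skip the GitHub download when all raw files are already present under data/Cora/raw. It is therefore worth checking the extracted layout before running. Below is a minimal sketch that uses the file names from the listing above. The helper missing_cora_files is hypothetical and not part of torch_geometric.

```python
from pathlib import Path

# file names taken from the data/Cora/raw listing above
CORA_RAW_FILES = [f"ind.cora.{suffix}" for suffix in
                  ("x", "y", "tx", "ty", "allx", "ally", "graph", "test.index")]


def missing_cora_files(root="./data"):
    """List the Cora raw files that are absent under <root>/Cora/raw."""
    raw_dir = Path(root) / "Cora" / "raw"
    return [name for name in CORA_RAW_FILES if not (raw_dir / name).is_file()]


if __name__ == "__main__":
    missing = missing_cora_files()
    print("all raw files present" if not missing else f"missing: {missing}")
```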

Now the code can be run:

from torch_geometric.datasets import Planetoid
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# Loads ./data/Cora; the download is skipped when the raw files are already there
dataset = Planetoid(root='./data', name='Cora')
print(dataset[0].y.shape)  # one label per node


class Net(torch.nn.Module):
    """Two-layer GCN: input features -> 16 hidden units -> one score per class."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)  # active only in train mode
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)  # log-probabilities, paired with nll_loss


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
# weight decay (L2 regularisation) is applied to the first layer only
optimizer = torch.optim.Adam([
    dict(params=model.conv1.parameters(), weight_decay=5e-4),
    dict(params=model.conv2.parameters(), weight_decay=0)
], lr=0.01)

for epoch in range(1000):
    model.train()  # re-enable dropout, since the evaluation below switches to eval mode
    optimizer.zero_grad()
    out = model(data)
    # the loss is computed on the labelled training nodes only
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    if epoch % 10 == 9:  # evaluate every 10 epochs
        model.eval()
        with torch.no_grad():
            logits, accs = model(data), []
        for _, mask in data('train_mask', 'val_mask', 'test_mask'):
            pred = logits[mask].max(1)[1]
            acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
            accs.append(acc)
        log = 'Epoch: {:03d}, Train: {:.5f}, Val: {:.5f}, Test: {:.5f}'
        print(log.format(epoch + 1, accs[0], accs[1], accs[2]))
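The evaluation loop computes accuracy as the number of correctly predicted nodes inside a mask, divided by the size of that mask. Cora's standard Planetoid split has 140 training, 500 validation and 1,000 test nodes. The first logged train accuracy of 0.97143 therefore means 136 of the 140 training nodes were classified correctly. Here is a NumPy sketch of the same calculation, independent of PyTorch. The function name masked_accuracy is hypothetical.

```python
import numpy as np


def masked_accuracy(logits, labels, mask):
    """Accuracy over the nodes selected by a boolean mask.

    logits: (N, C) class scores; log-probabilities work too, the argmax is unchanged
    labels: (N,) integer class labels
    mask:   (N,) boolean array, playing the role of data.train_mask / val_mask / test_mask
    """
    pred = logits[mask].argmax(axis=1)   # equivalent to logits[mask].max(1)[1] in PyTorch
    return (pred == labels[mask]).sum() / mask.sum()


if __name__ == "__main__":
    logits = np.array([[2.0, 0.1], [0.3, 1.5], [1.0, 0.2], [0.1, 0.9]])
    labels = np.array([0, 1, 1, 1])
    train_mask = np.array([True, True, True, False])
    print(masked_accuracy(logits, labels, train_mask))   # 2 of the 3 masked nodes are correct
```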
        

Output (no random seed is fixed, so the exact numbers vary from run to run):

torch.Size([2708])
Epoch: 010, Train: 0.97143, Val: 0.68800, Test: 0.72700
Epoch: 020, Train: 0.99286, Val: 0.77200, Test: 0.78000
Epoch: 030, Train: 1.00000, Val: 0.76800, Test: 0.77700
Epoch: 040, Train: 1.00000, Val: 0.76800, Test: 0.77800
Epoch: 050, Train: 1.00000, Val: 0.76800, Test: 0.78600
Epoch: 060, Train: 1.00000, Val: 0.77000, Test: 0.79000
Epoch: 070, Train: 1.00000, Val: 0.77400, Test: 0.79600
Epoch: 080, Train: 1.00000, Val: 0.77400, Test: 0.79600
Epoch: 090, Train: 1.00000, Val: 0.77400, Test: 0.79500
Epoch: 100, Train: 1.00000, Val: 0.77400, Test: 0.79400
Epoch: 110, Train: 1.00000, Val: 0.77200, Test: 0.79500
Epoch: 120, Train: 1.00000, Val: 0.77000, Test: 0.79500
Epoch: 130, Train: 1.00000, Val: 0.76800, Test: 0.79400
Epoch: 140, Train: 1.00000, Val: 0.76600, Test: 0.79400
Epoch: 150, Train: 1.00000, Val: 0.76400, Test: 0.79700
Epoch: 160, Train: 1.00000, Val: 0.76400, Test: 0.79700
Epoch: 170, Train: 1.00000, Val: 0.76200, Test: 0.80100
Epoch: 180, Train: 1.00000, Val: 0.76400, Test: 0.80200
Epoch: 190, Train: 1.00000, Val: 0.76400, Test: 0.80300
Epoch: 200, Train: 1.00000, Val: 0.76600, Test: 0.80300
Epoch: 210, Train: 1.00000, Val: 0.76600, Test: 0.80400
Epoch: 220, Train: 1.00000, Val: 0.76600, Test: 0.80400
Epoch: 230, Train: 1.00000, Val: 0.76400, Test: 0.80400
Epoch: 240, Train: 1.00000, Val: 0.76600, Test: 0.80400
Epoch: 250, Train: 1.00000, Val: 0.76000, Test: 0.80400
Epoch: 260, Train: 1.00000, Val: 0.76200, Test: 0.80500
Epoch: 270, Train: 1.00000, Val: 0.76600, Test: 0.80700
Epoch: 280, Train: 1.00000, Val: 0.76800, Test: 0.80500
Epoch: 290, Train: 1.00000, Val: 0.77000, Test: 0.80500
Epoch: 300, Train: 1.00000, Val: 0.77000, Test: 0.80500
Epoch: 310, Train: 1.00000, Val: 0.77200, Test: 0.80700
Epoch: 320, Train: 1.00000, Val: 0.77400, Test: 0.80400
Epoch: 330, Train: 1.00000, Val: 0.77400, Test: 0.80400
Epoch: 340, Train: 1.00000, Val: 0.77400, Test: 0.80400
Epoch: 350, Train: 1.00000, Val: 0.77400, Test: 0.80600
Epoch: 360, Train: 1.00000, Val: 0.77400, Test: 0.80700
Epoch: 370, Train: 1.00000, Val: 0.77400, Test: 0.80600
Epoch: 380, Train: 1.00000, Val: 0.77600, Test: 0.80700
Epoch: 390, Train: 1.00000, Val: 0.77600, Test: 0.80600
Epoch: 400, Train: 1.00000, Val: 0.77400, Test: 0.80600
Epoch: 410, Train: 1.00000, Val: 0.77400, Test: 0.80600
Epoch: 420, Train: 1.00000, Val: 0.77400, Test: 0.80700
Epoch: 430, Train: 1.00000, Val: 0.77400, Test: 0.80800
Epoch: 440, Train: 1.00000, Val: 0.77600, Test: 0.80800
Epoch: 450, Train: 1.00000, Val: 0.77600, Test: 0.80800
Epoch: 460, Train: 1.00000, Val: 0.77400, Test: 0.80800
Epoch: 470, Train: 1.00000, Val: 0.77600, Test: 0.80800
Epoch: 480, Train: 1.00000, Val: 0.77400, Test: 0.80900
Epoch: 490, Train: 1.00000, Val: 0.77600, Test: 0.80900
Epoch: 500, Train: 1.00000, Val: 0.77600, Test: 0.81000
Epoch: 510, Train: 1.00000, Val: 0.77600, Test: 0.81000
Epoch: 520, Train: 1.00000, Val: 0.77400, Test: 0.81200
Epoch: 530, Train: 1.00000, Val: 0.77600, Test: 0.81200
Epoch: 540, Train: 1.00000, Val: 0.77600, Test: 0.81200
Epoch: 550, Train: 1.00000, Val: 0.77600, Test: 0.81100
Epoch: 560, Train: 1.00000, Val: 0.77800, Test: 0.81100
Epoch: 570, Train: 1.00000, Val: 0.77600, Test: 0.81100
Epoch: 580, Train: 1.00000, Val: 0.77600, Test: 0.81200
Epoch: 590, Train: 1.00000, Val: 0.77600, Test: 0.81200
Epoch: 600, Train: 1.00000, Val: 0.77600, Test: 0.81200
Epoch: 610, Train: 1.00000, Val: 0.77600, Test: 0.81200
Epoch: 620, Train: 1.00000, Val: 0.77600, Test: 0.81200
Epoch: 630, Train: 1.00000, Val: 0.77600, Test: 0.81200
Epoch: 640, Train: 1.00000, Val: 0.77600, Test: 0.81400
Epoch: 650, Train: 1.00000, Val: 0.77600, Test: 0.81400
Epoch: 660, Train: 1.00000, Val: 0.77600, Test: 0.81300
Epoch: 670, Train: 1.00000, Val: 0.77600, Test: 0.81300
Epoch: 680, Train: 1.00000, Val: 0.77600, Test: 0.81300
Epoch: 690, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 700, Train: 1.00000, Val: 0.77600, Test: 0.81300
Epoch: 710, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 720, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 730, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 740, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 750, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 760, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 770, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 780, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 790, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 800, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 810, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 820, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 830, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 840, Train: 1.00000, Val: 0.77800, Test: 0.81200
Epoch: 850, Train: 1.00000, Val: 0.77800, Test: 0.81300
Epoch: 860, Train: 1.00000, Val: 0.77800, Test: 0.81200
Epoch: 870, Train: 1.00000, Val: 0.77800, Test: 0.81200
Epoch: 880, Train: 1.00000, Val: 0.77800, Test: 0.81200
Epoch: 890, Train: 1.00000, Val: 0.77800, Test: 0.81200
Epoch: 900, Train: 1.00000, Val: 0.77800, Test: 0.81200
Epoch: 910, Train: 1.00000, Val: 0.77800, Test: 0.81100
Epoch: 920, Train: 1.00000, Val: 0.77800, Test: 0.81100
Epoch: 930, Train: 1.00000, Val: 0.77800, Test: 0.81100
Epoch: 940, Train: 1.00000, Val: 0.77800, Test: 0.81100
Epoch: 950, Train: 1.00000, Val: 0.77800, Test: 0.81100
Epoch: 960, Train: 1.00000, Val: 0.77800, Test: 0.81100
Epoch: 970, Train: 1.00000, Val: 0.77800, Test: 0.81100
Epoch: 980, Train: 1.00000, Val: 0.77800, Test: 0.81100
Epoch: 990, Train: 1.00000, Val: 0.77800, Test: 0.81100
Epoch: 1000, Train: 1.00000, Val: 0.77800, Test: 0.81100
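The log prints the test accuracy at every checkpoint, but picking the best test number would use the test set for model selection. A common protocol is to choose the checkpoint with the highest validation accuracy and report its test accuracy. Below is a small sketch that parses lines in the format printed above. best_val_checkpoint is a hypothetical helper, not part of torch_geometric.

```python
import re

# matches lines like "Epoch: 010, Train: 0.97143, Val: 0.68800, Test: 0.72700"
LOG_PATTERN = re.compile(
    r"Epoch: (\d+), Train: ([\d.]+), Val: ([\d.]+), Test: ([\d.]+)")


def best_val_checkpoint(log_text):
    """Return (epoch, train, val, test) for the first line with the highest Val accuracy."""
    rows = [(int(m[1]), float(m[2]), float(m[3]), float(m[4]))
            for m in LOG_PATTERN.finditer(log_text)]
    if not rows:
        raise ValueError("no log lines found")
    return max(rows, key=lambda row: row[2])   # max() keeps the first of tied rows
```

Applied to the full log above, this selects epoch 560 (Val 0.778, Test 0.811). That is the first checkpoint to reach the highest validation accuracy.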

Note: the code is adapted from the examples at https://github.com/rusty1s/pytorch_geometric/tree/master/examples
