用DGL实现一个简单的GCN-cora的例子。
参考:https://docs.dgl.ai/tutorials/blitz/1_introduction.html#sphx-glr-tutorials-blitz-1-introduction-py
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
Using backend: pytorch
import dgl.data
# Load the Cora dataset through DGL's built-in loader and print the dataset object.
dataset = dgl.data.CoraGraphDataset()
print(dataset)
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done loading data from cached files.
# Take the first graph stored in the dataset and print its description.
g = dataset[0]
print(g)
Graph(num_nodes=2708, num_edges=10556,
ndata_schemes={'feat': Scheme(shape=(1433,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool)}
edata_schemes={})
# Show the node-level and edge-level data attached to the graph.
print(g.ndata)
print(g.edata)
{'feat': tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]), 'label': tensor([3, 4, 4, ..., 3, 3, 3]), 'val_mask': tensor([False, False, False, ..., False, False, False]), 'test_mask': tensor([False, False, False, ..., True, True, True]), 'train_mask': tensor([ True, True, True, ..., False, False, False])}
{}
# Print the shapes of the node feature tensor and the node label tensor.
print(g.ndata['feat'].shape)
print(g.ndata['label'].shape)
torch.Size([2708, 1433])
torch.Size([2708])
from dgl.nn import GraphConv
class GCN(nn.Module):
    """Two-layer graph convolutional network for node classification.

    Args:
        in_feats: Length of each input node feature vector.
        h_feats: Width of the hidden layer produced by the first convolution.
        num_classes: Number of classes, i.e. the width of the output layer.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Compute per-node class scores.

        Args:
            g: Graph passed to both ``GraphConv`` layers.
            in_feat: Node features fed into the first layer.

        Returns:
            Unnormalized class scores (logits), one row per node. Apply
            ``softmax`` to the result if probabilities are needed; ``argmax``
            over the logits yields the same predictions.
        """
        x = F.relu(self.conv1(g, in_feat))
        # Return raw logits. nn.CrossEntropyLoss applies log-softmax itself,
        # so a softmax here would be applied twice during training and
        # flatten the gradients (and F.softmax without `dim` is deprecated).
        return self.conv2(g, x)
# Build the model: the input width is the per-node feature length, the hidden
# layer has 16 units, and the output layer has one unit per class.
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
print(model)
GCN(
(conv1): GraphConv(in=1433, out=16, normalization=both, activation=None)
(conv2): GraphConv(in=16, out=7, normalization=both, activation=None)
)
# Use the GPU when CUDA is available, otherwise fall back to the CPU,
# then move both the model and the graph to that device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = model.to(device)
g = g.to(device)
cpu
# Pull the boolean split masks, node features and node labels out of the
# graph's node data (read after the graph was moved to the target device).
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
feat = g.ndata['feat']
label = g.ndata['label']
def train(optimizer=None):
    """Run one full-batch training step and report accuracy on every split.

    Uses the module-level ``model``, ``g``, ``feat``, ``label`` and the
    ``train_mask`` / ``val_mask`` / ``test_mask`` tensors.

    Args:
        optimizer: Optional optimizer to step. When omitted, an Adam
            optimizer (lr=0.01) is created for the current ``model`` on the
            first call and reused on later calls, so Adam's running moment
            estimates persist across epochs instead of being reset each time.

    Returns:
        Tuple ``(loss, train_acc, val_acc, test_acc)``. ``loss`` is a Python
        float; the accuracies are 0-dim tensors computed from the forward
        pass made before this step's parameter update.
    """
    if optimizer is None:
        # Cache one optimizer per model object; rebuild only if ``model``
        # has been replaced (e.g. a notebook cell was re-run).
        cached = getattr(train, '_optimizer_cache', None)
        if cached is None or cached[0] is not model:
            cached = (model, torch.optim.Adam(model.parameters(), lr=0.01))
            train._optimizer_cache = cached
        optimizer = cached[1]

    criterion = nn.CrossEntropyLoss()
    out = model(g, feat)
    # Only the training nodes contribute to the loss.
    loss = criterion(out[train_mask], label[train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    pred = out.argmax(dim=1)
    train_acc = (pred[train_mask] == label[train_mask]).float().mean()
    val_acc = (pred[val_mask] == label[val_mask]).float().mean()
    test_acc = (pred[test_mask] == label[test_mask]).float().mean()
    return loss.item(), train_acc, val_acc, test_acc
def main():
    """Train for 100 epochs, logging accuracy on every split each epoch.

    Tracks the highest validation accuracy seen so far together with the
    test accuracy from that same epoch, and prints both once training ends.
    """
    top_val = 0
    test_at_top_val = 0
    for epoch in range(100):
        _, train_acc, val_acc, test_acc = train()
        # Only a strictly better validation score replaces the tracked best,
        # so ties keep the earlier epoch's numbers.
        if val_acc > top_val:
            top_val, test_at_top_val = val_acc, test_acc
        summary = 'epoch:{:03d}, train_acc:{:.4f}, val_acc:{:.4f}, test_acc:{:.4f}'.format(
            epoch, train_acc, val_acc, test_acc)
        print(summary)
    print('best_val_acc:', top_val)
    print('best_test_acc:', test_at_top_val)
# Run the training loop only when this file is executed as a script.
if __name__ == "__main__":
    main()
epoch:000, train_acc:0.1357, val_acc:0.1220, test_acc:0.1130
epoch:001, train_acc:0.2143, val_acc:0.1960, test_acc:0.1920
epoch:002, train_acc:0.6857, val_acc:0.4620, test_acc:0.4530
epoch:003, train_acc:0.5143, val_acc:0.2640, test_acc:0.2750
epoch:004, train_acc:0.5000, val_acc:0.2760, test_acc:0.2900
epoch:005, train_acc:0.5500, val_acc:0.3300, test_acc:0.3360
epoch:006, train_acc:0.6143, val_acc:0.3200, test_acc:0.3310
epoch:007, train_acc:0.6357, val_acc:0.3260, test_acc:0.3390
epoch:008, train_acc:0.6143, val_acc:0.3420, test_acc:0.3530
epoch:009, train_acc:0.6500, val_acc:0.3420, test_acc:0.3610
epoch:010, train_acc:0.6357, val_acc:0.3260, test_acc:0.3470
epoch:011, train_acc:0.6429, val_acc:0.3160, test_acc:0.3390
epoch:012, train_acc:0.6571, val_acc:0.3320, test_acc:0.3430
epoch:013, train_acc:0.6500, val_acc:0.3220, test_acc:0.3460
epoch:014, train_acc:0.6357, val_acc:0.3220, test_acc:0.3410
epoch:015, train_acc:0.6571, val_acc:0.3340, test_acc:0.3480
epoch:016, train_acc:0.6500, val_acc:0.3320, test_acc:0.3480
epoch:017, train_acc:0.6571, val_acc:0.3400, test_acc:0.3540
epoch:018, train_acc:0.6714, val_acc:0.3460, test_acc:0.3530
epoch:019, train_acc:0.6929, val_acc:0.3520, test_acc:0.3640
epoch:020, train_acc:0.6929, val_acc:0.3560, test_acc:0.3720
epoch:021, train_acc:0.6929, val_acc:0.3580, test_acc:0.3720
epoch:022, train_acc:0.7071, val_acc:0.3740, test_acc:0.3820
epoch:023, train_acc:0.7071, val_acc:0.3700, test_acc:0.3870
epoch:024, train_acc:0.7071, val_acc:0.3800, test_acc:0.3910
epoch:025, train_acc:0.7071, val_acc:0.3880, test_acc:0.4000
epoch:026, train_acc:0.7286, val_acc:0.3820, test_acc:0.4070
epoch:027, train_acc:0.7643, val_acc:0.4200, test_acc:0.4240
epoch:028, train_acc:0.7857, val_acc:0.4240, test_acc:0.4280
epoch:029, train_acc:0.8000, val_acc:0.4500, test_acc:0.4540
epoch:030, train_acc:0.8214, val_acc:0.4680, test_acc:0.4690
epoch:031, train_acc:0.8286, val_acc:0.4820, test_acc:0.4810
epoch:032, train_acc:0.8286, val_acc:0.4920, test_acc:0.4830
epoch:033, train_acc:0.8286, val_acc:0.5060, test_acc:0.4980
epoch:034, train_acc:0.8286, val_acc:0.5080, test_acc:0.5090
epoch:035, train_acc:0.8357, val_acc:0.5120, test_acc:0.5170
epoch:036, train_acc:0.8429, val_acc:0.5120, test_acc:0.5210
epoch:037, train_acc:0.8571, val_acc:0.5140, test_acc:0.5310
epoch:038, train_acc:0.8714, val_acc:0.5240, test_acc:0.5420
epoch:039, train_acc:0.8714, val_acc:0.5340, test_acc:0.5450
epoch:040, train_acc:0.8929, val_acc:0.5480, test_acc:0.5480
epoch:041, train_acc:0.9071, val_acc:0.5640, test_acc:0.5630
epoch:042, train_acc:0.9071, val_acc:0.5640, test_acc:0.5680
epoch:043, train_acc:0.9214, val_acc:0.5780, test_acc:0.5830
epoch:044, train_acc:0.9286, val_acc:0.5840, test_acc:0.5890
epoch:045, train_acc:0.9429, val_acc:0.6080, test_acc:0.6110
epoch:046, train_acc:0.9429, val_acc:0.6120, test_acc:0.6150
epoch:047, train_acc:0.9429, val_acc:0.6360, test_acc:0.6450
epoch:048, train_acc:0.9429, val_acc:0.6500, test_acc:0.6490
epoch:049, train_acc:0.9429, val_acc:0.6600, test_acc:0.6610
epoch:050, train_acc:0.9571, val_acc:0.6760, test_acc:0.6670
epoch:051, train_acc:0.9571, val_acc:0.6820, test_acc:0.6820
epoch:052, train_acc:0.9643, val_acc:0.7000, test_acc:0.7000
epoch:053, train_acc:0.9643, val_acc:0.7020, test_acc:0.7000
epoch:054, train_acc:0.9643, val_acc:0.7140, test_acc:0.7100
epoch:055, train_acc:0.9643, val_acc:0.7240, test_acc:0.7230
epoch:056, train_acc:0.9643, val_acc:0.7260, test_acc:0.7220
epoch:057, train_acc:0.9643, val_acc:0.7460, test_acc:0.7400
epoch:058, train_acc:0.9643, val_acc:0.7360, test_acc:0.7370
epoch:059, train_acc:0.9643, val_acc:0.7540, test_acc:0.7400
epoch:060, train_acc:0.9643, val_acc:0.7560, test_acc:0.7500
epoch:061, train_acc:0.9643, val_acc:0.7600, test_acc:0.7510
epoch:062, train_acc:0.9714, val_acc:0.7580, test_acc:0.7520
epoch:063, train_acc:0.9714, val_acc:0.7620, test_acc:0.7560
epoch:064, train_acc:0.9714, val_acc:0.7620, test_acc:0.7550
epoch:065, train_acc:0.9714, val_acc:0.7660, test_acc:0.7550
epoch:066, train_acc:0.9714, val_acc:0.7720, test_acc:0.7650
epoch:067, train_acc:0.9786, val_acc:0.7700, test_acc:0.7590
epoch:068, train_acc:0.9714, val_acc:0.7800, test_acc:0.7750
epoch:069, train_acc:0.9786, val_acc:0.7740, test_acc:0.7660
epoch:070, train_acc:0.9714, val_acc:0.7740, test_acc:0.7630
epoch:071, train_acc:0.9786, val_acc:0.7780, test_acc:0.7730
epoch:072, train_acc:0.9786, val_acc:0.7720, test_acc:0.7660
epoch:073, train_acc:0.9857, val_acc:0.7820, test_acc:0.7820
epoch:074, train_acc:0.9786, val_acc:0.7800, test_acc:0.7700
epoch:075, train_acc:0.9857, val_acc:0.7900, test_acc:0.7830
epoch:076, train_acc:0.9786, val_acc:0.7780, test_acc:0.7730
epoch:077, train_acc:0.9929, val_acc:0.7840, test_acc:0.7770
epoch:078, train_acc:0.9786, val_acc:0.7780, test_acc:0.7760
epoch:079, train_acc:0.9929, val_acc:0.7920, test_acc:0.7800
epoch:080, train_acc:0.9857, val_acc:0.7820, test_acc:0.7750
epoch:081, train_acc:0.9929, val_acc:0.7920, test_acc:0.7840
epoch:082, train_acc:0.9929, val_acc:0.7820, test_acc:0.7770
epoch:083, train_acc:0.9929, val_acc:0.7960, test_acc:0.7870
epoch:084, train_acc:0.9929, val_acc:0.7860, test_acc:0.7760
epoch:085, train_acc:0.9929, val_acc:0.7940, test_acc:0.7860
epoch:086, train_acc:0.9929, val_acc:0.7860, test_acc:0.7780
epoch:087, train_acc:0.9929, val_acc:0.7940, test_acc:0.7910
epoch:088, train_acc:0.9929, val_acc:0.7820, test_acc:0.7800
epoch:089, train_acc:0.9929, val_acc:0.7920, test_acc:0.7890
epoch:090, train_acc:0.9929, val_acc:0.7880, test_acc:0.7780
epoch:091, train_acc:0.9929, val_acc:0.7900, test_acc:0.7870
epoch:092, train_acc:0.9929, val_acc:0.7820, test_acc:0.7780
epoch:093, train_acc:0.9929, val_acc:0.7860, test_acc:0.7870
epoch:094, train_acc:0.9929, val_acc:0.7800, test_acc:0.7770
epoch:095, train_acc:0.9929, val_acc:0.7880, test_acc:0.7880
epoch:096, train_acc:0.9929, val_acc:0.7860, test_acc:0.7830
epoch:097, train_acc:0.9929, val_acc:0.7880, test_acc:0.7870
epoch:098, train_acc:0.9929, val_acc:0.7820, test_acc:0.7830
epoch:099, train_acc:0.9929, val_acc:0.7920, test_acc:0.7920
best_val_acc: tensor(0.7960)
best_test_acc: tensor(0.7870)