Official code
# 1. Training a GNN generally requires building the following 5 pieces of data:
1. graph: a dgl graph built from the edge list
2. labels: the label of each node
3. train_idx: the node indices of the training set
4. valid_idx: the node indices of the validation set
5. node_feat: the feature vector of each node
# 2. Use dgl's sampler and dataloader to implement node sampling
from dgl.dataloading import MultiLayerNeighborSampler, NodeDataLoader

# Number of neighbors sampled per node at each layer;
# the length of fanouts must match the number of conv layers (3 here).
fanouts = [20, 20, 20]
train_sampler = MultiLayerNeighborSampler(fanouts)
train_dataloader = NodeDataLoader(graph,
                                  train_idx,
                                  train_sampler,
                                  device=device,
                                  use_ddp=False,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  drop_last=False,
                                  num_workers=num_workers)
# Similar to GCN: stack multiple conv layers, minding the input and output dimensions
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, 'mean'))
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size
    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h
for step, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
    batch_inputs = node_feat[input_nodes].to(device)
    batch_labels = labels[output_nodes].to(device)
    blocks = [block.to(device) for block in blocks]
    train_batch_logits = model(blocks, batch_inputs)
    train_loss = loss_fn(train_batch_logits, batch_labels)
    # Standard backward pass (assumes an `optimizer` has been created beforehand)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
# Explanation of the data format returned by train_dataloader:
1. Suppose batch_size=64, 2 conv layers, and 10 neighbors sampled per node at each layer (in practice, graph training usually uses batches in the thousands)
2. Then we start from 64 nodes, each sampling 10 neighbors; some nodes have fewer than 10, so suppose the first layer samples 125 nodes in total
3. Likewise the second layer samples on top of those 125, giving 298 nodes in total
4. input_nodes: the indices of the final 298 nodes
5. output_nodes: the indices of the initial 64 nodes
6. blocks: the sampled edge relations, i.e. the 298->125 mapping and the 125->64 mapping
7. The model convolves the 298 nodes down to 125, then the 125 down to 64
8. So the output is the vector representation of the 64 nodes
# Sample validation log:
Epoch:009|batch:0000, val_loss:1.617950, val_acc:0.8281
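The val_acc in the log is just argmax accuracy over the validation logits. A minimal sketch (the logits and labels below are made-up):

```python
import torch

def accuracy(logits, labels):
    """Fraction of rows where the argmax class matches the label."""
    preds = logits.argmax(dim=1)
    return (preds == labels).float().mean().item()

logits = torch.tensor([[2.0, 0.1],
                       [0.2, 1.5],
                       [3.0, 0.0],
                       [0.0, 1.0]])
labels = torch.tensor([0, 1, 1, 1])
acc = accuracy(logits, labels)  # 3 of 4 predictions correct -> 0.75
```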