使用 PyG 在大图上训练 GNN

今天想在 Reddit 数据集上跑实验，发现 PyG 的 SAGEConv 层本身并不负责邻居采样和小批量（mini-batch）训练；整张 Reddit 图一次性做全图训练开销太大，需要借助 PyG 提供的 NeighborLoader 先对邻居进行采样，再按批次训练。

PyG 中的 NeighborLoader（torch_geometric.loader.NeighborLoader）

下面附上我自己实现的、在 Reddit 数据集上训练 GraphSAGE 的代码：
GraphSage_batch.py:

import torch
from torch_geometric.nn import SAGEConv
from tqdm import tqdm
import torch.nn.functional as F

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class GraphSageNet(torch.nn.Module):
    """A stack of ``SAGEConv`` layers for node classification.

    ``forward`` runs on the sampled sub-graphs produced by a neighbour loader
    during training; ``inference`` computes representations for every node
    layer by layer.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of every hidden layer.
        out_channels: Size of the final output (e.g. the number of classes).
        num_layers: Number of ``SAGEConv`` layers; the default of 2 reproduces
            the original ``in -> hidden -> out`` architecture. A training
            loader's ``num_neighbors`` list presumably needs one entry per
            layer — confirm against the loader you pair this with.
        dropout: Dropout probability applied after each hidden layer's ReLU
            (only while ``self.training`` is True).

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=2, dropout=0.5):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')
        self.dropout = dropout
        # Channel sizes at every layer boundary, e.g. [in, hidden, out] for 2 layers.
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            SAGEConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x, edge_index):
        """Apply every conv layer; ReLU + dropout between layers, none after the last."""
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                x = x.relu_()
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    @torch.no_grad()
    def inference(self, x_all, subgraph_loader):
        """Compute output representations for all nodes, one layer at a time.

        Each layer is applied to every batch yielded by ``subgraph_loader``
        before moving on to the next layer. The loader's batches must expose
        ``n_id`` (global ids used to gather rows of ``x_all``),
        ``edge_index`` and ``batch_size``.

        Returns:
            A CPU tensor with one row per node, in the order the loader
            visits its seed nodes.
        """
        # Use the device the model actually lives on rather than a module-level
        # global, so inference still works if the model was moved elsewhere.
        model_device = next(self.parameters()).device

        pbar = tqdm(total=len(subgraph_loader.dataset) * len(self.convs))
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch:
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            xs = []
            for batch in subgraph_loader:
                x = x_all[batch.n_id.to(x_all.device)].to(model_device)
                x = conv(x, batch.edge_index.to(model_device))
                if i < last:
                    x = x.relu_()
                # Only the first `batch_size` rows belong to this batch's seed
                # nodes; keep those and park them on the CPU to save GPU memory.
                xs.append(x[:batch.batch_size].cpu())
                pbar.update(batch.batch_size)
            # Concatenating in iteration order assumes the loader is not
            # shuffled and visits each node exactly once.
            x_all = torch.cat(xs, dim=0)
        pbar.close()
        return x_all

GraphSage_reddit.py:

import copy

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Reddit

import sys

from torch_geometric.loader import NeighborLoader
from tqdm import tqdm
from sklearn.metrics import f1_score
from GraphSage_batch import GraphSageNet

sys.path.append("..")

from early_stop import EarlyStop

# Use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# PyG Reddit dataset stored under ../data/Reddit; only dataset[0] is used below.
dataset = Reddit(root='../data/Reddit')

# Move only the node features `x` and labels `y` to `device`; other attributes
# (e.g. edge_index, the split masks) are left where they were loaded. The
# training loop moves each batch's edge_index to `device` itself.
data = dataset[0].to(device, 'x', 'y')

# Loader settings shared by the training and evaluation loaders.
kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}
# Training: mini-batches of seed nodes drawn from the training mask, sampling
# 25 neighbours at the first hop and 10 at the second.
# NOTE(review): presumably one entry per SAGEConv layer in the model — confirm
# if the model depth changes.
train_loader = NeighborLoader(data, input_nodes=data.train_mask,
                              num_neighbors=[25, 10], shuffle=True, **kwargs)

# Evaluation: every node as a seed (input_nodes=None), with all 1-hop
# neighbours (-1 means no sampling limit in PyG). GraphSageNet.inference
# applies one layer per pass over this loader. shuffle=False keeps the node
# order stable so per-batch outputs can be concatenated back together.
# copy.copy makes a shallow copy so the `del` below does not touch `data`.
subgraph_loader = NeighborLoader(copy.copy(data), input_nodes=None,
                                 num_neighbors=[-1], shuffle=False, **kwargs)

# No need to maintain these features during evaluation:
del subgraph_loader.data.x, subgraph_loader.data.y
# Add global node index information.
# num_nodes is set explicitly because x (from which it could otherwise be
# inferred) was just deleted; n_id gives each sampled node its global id,
# which GraphSageNet.inference uses to gather rows of the full feature matrix.
subgraph_loader.data.num_nodes = data.num_nodes
subgraph_loader.data.n_id = torch.arange(data.num_nodes)

# GraphSAGE with 256 hidden channels, sized from the dataset.
model = GraphSageNet(dataset.num_features, 256, dataset.num_classes).to(device)

# Adam with L2 weight decay.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


# One training epoch.
def train(epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Loss is computed only on each batch's seed nodes (the first
    ``batch.batch_size`` rows); the other rows are sampled neighbours.

    Args:
        epoch: Epoch index, used only for the progress-bar label.

    Returns:
        The mean cross-entropy loss per seed node over the epoch.
    """
    model.train()

    progress = tqdm(total=int(len(train_loader.dataset)))
    progress.set_description(f'Epoch {epoch:02d}')

    loss_sum = 0
    seen = 0
    for batch in train_loader:
        n_seed = batch.batch_size
        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index.to(device))
        loss = F.cross_entropy(logits[:n_seed], batch.y[:n_seed])
        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its seed count so the epoch
        # average is per node, not per batch.
        loss_sum += float(loss) * n_seed
        seen += n_seed
        progress.update(n_seed)
    progress.close()

    return loss_sum / seen


@torch.no_grad()
def test(mask):
    """Return the micro-averaged F1 score on the nodes selected by ``mask``.

    Pass ``data.val_mask`` to score the validation set or ``data.test_mask``
    to score the test set. Predictions come from ``model.inference`` run over
    ``subgraph_loader``; the labels are moved to the predictions' device
    before both are indexed by ``mask``.
    """
    model.eval()
    logits = model.inference(data.x, subgraph_loader)
    pred = logits.argmax(dim=-1)
    labels = data.y.to(pred.device)
    return f1_score(labels[mask], pred[mask], average='micro')


patience = 20
early_stop = EarlyStop(patience)
for epoch in range(100):
    # Train for one epoch on the training nodes.
    loss = train(epoch)
    # Validation micro-F1 drives early stopping.
    f1 = test(data.val_mask)
    print(f'epoch:{epoch},loss:{loss},f1:{f1}')

    # NOTE(review): assumes EarlyStop.step returns False when training should
    # stop and saves the best model so far to ./model/best_model.pth — confirm
    # against early_stop.py.
    if not early_stop.step(f1, model):
        break

# Reload the best checkpoint whether or not early stopping fired, so the test
# score always comes from the best-validation model. Previously, if all 100
# epochs ran, the last-epoch model was evaluated instead.
best_model_path = './model/best_model.pth'
try:
    # The checkpoint is a whole pickled nn.Module, which PyTorch >= 2.6 refuses
    # to load unless weights_only=False. Only do this for checkpoints this
    # script wrote itself. map_location puts the model on the current device.
    model = torch.load(best_model_path, map_location=device, weights_only=False)
except TypeError:
    # torch < 1.13 does not accept the weights_only keyword.
    model = torch.load(best_model_path, map_location=device)

f1 = test(data.test_mask)
print('f1: {:.4f}'.format(f1))

你可能感兴趣的:(机器学习,graph,ML&&DL,pytorch,深度学习,python)