[Code] GraphSAGE Source Code Walkthrough

1. GraphSAGE

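The code in this post comes from the DGL examples; if you are interested, you can find it on GitHub.

I read the code mainly to deepen my understanding of the paper, and secondly to see how experienced engineers implement the algorithm. While reading the GraphSAGE code I also found details I had overlooked and a few things I had misunderstood. For example, I had ignored how the four GraphSAGE aggregators are actually implemented, and my understanding of Algorithm 2 was also off. Looking back at my earlier GraphSAGE post, it is painful to read.

On to the main topic, starting with a brief recap of GraphSAGE.

The core algorithm:

[Figure 1: the core GraphSAGE algorithm]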

2. SAGEConv

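DGL already implements the SAGEConv layer, so we can import it directly.

With SAGEConv available, GraphSAGE is fairly simple to implement.

The only difference from a GCN built on GraphConv is that GraphConv is replaced by SAGEConv: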

class GraphSAGE(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 aggregator_type):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.g = g
        # input layer
        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type,
                                    feat_drop=dropout, activation=activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type,
                                        feat_drop=dropout, activation=activation))
        # output layer
        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type,
                                    feat_drop=dropout, activation=None)) # activation None
        
    def forward(self, features):
        h = features
        for layer in self.layers:
            h = layer(self.g, h)
        return h

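Now let's look at how SAGEConv is implemented.

SAGEConv takes seven parameters:

  • in_feats: the input feature size, either a single int or a pair of ints. On a unidirectional bipartite graph, the pair gives the feature sizes of the source and destination nodes respectively; if only one int is given, source and destination nodes are assumed to have the same feature size. Note that if aggregator_type is gcn, the source and destination feature sizes must be equal;
  • out_feats: the output feature size;
  • aggregator_type: the aggregator type. mean, gcn, pool and lstm are currently supported, i.e. a gcn aggregator on top of the three described in Section 3.3 of the paper. The gcn aggregator can be understood as the mean over all neighbors together with the current node;
  • feat_drop=0.: the dropout probability applied to the features, 0 by default;
  • bias=True: whether the output layer has a bias, True by default;
  • norm=None: an optional normalization function applied to the output, None by default;
  • activation=None: an optional activation function applied to the updated node features, None by default.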
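To make the int-or-pair convention for in_feats concrete, here is a minimal sketch in plain Python. The helper names expand_pair and resolve_feat_sizes are hypothetical stand-ins written for this post; they only mimic the simple cases described above and are not DGL's own implementation (in DGL the gcn size check happens on the feature tensors inside forward).

```python
from typing import Tuple, Union

FeatSize = Union[int, Tuple[int, int]]


def expand_pair(in_feats: FeatSize) -> Tuple[int, int]:
    """Hypothetical, simplified stand-in for expand_as_pair.

    A single int is used for both source and destination nodes;
    a pair is returned unchanged as (src_size, dst_size).
    """
    if isinstance(in_feats, tuple):
        src_feats, dst_feats = in_feats
        return src_feats, dst_feats
    return in_feats, in_feats


def resolve_feat_sizes(in_feats: FeatSize, aggregator_type: str) -> Tuple[int, int]:
    """Illustrative check (an assumption of this sketch): gcn needs equal sizes."""
    src_feats, dst_feats = expand_pair(in_feats)
    if aggregator_type == "gcn" and src_feats != dst_feats:
        raise ValueError("gcn aggregator requires equal src and dst feature sizes")
    return src_feats, dst_feats


same_sizes = resolve_feat_sizes(602, "mean")        # one int: shared by src and dst
pair_sizes = resolve_feat_sizes((602, 16), "mean")  # pair: src and dst sizes differ
```

The gcn restriction follows from the aggregation itself: the node's own feature is averaged together with its neighbors' features, so the two must live in the same space.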
class SAGEConv(nn.Module):
    

    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        # expand_as_pair returns a 2-tuple: (source feature size, destination feature size).
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type != 'gcn':
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize parameters.
        The gain is the value recommended by calculate_gain for the given
        nonlinearity and is used to scale the initialization.
        """
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _lstm_reducer(self, nodes):
        """LSTM reducer
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox['m'] # (B, L, D)
        batch_size = m.shape[0]
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)
        return {'neigh': rst.squeeze(0)}

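    def forward(self, graph, feat):
        """Forward pass of the SAGE layer.
        Takes a DGLGraph and the node features as a Tensor (or a pair of Tensors).
        """
        # local_var returns a graph object meant for use inside this function.
        # Out-of-place feature updates made here are not reflected in the caller's graph,
        # which protects the original data from accidental modification.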
        graph = graph.local_var()

        if isinstance(feat, tuple):
            feat_src = self.feat_drop(feat[0])
            feat_dst = self.feat_drop(feat[1])
        else:
            feat_src = feat_dst = self.feat_drop(feat)

        h_self = feat_dst

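        # Pick the aggregation according to the aggregator type.
        # Note that Section 3.3 of the paper describes only three aggregators,
        # while there is an extra gcn aggregator here;
        # the reason is given later in this post.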
        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            # check_eq_shape verifies that source and destination features have the same size
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst     # same as above if homogeneous
            graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
            # divide in_degrees
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'lstm':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
            h_neigh = graph.dstdata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

        # GraphSAGE GCN does not require fc_self.
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        # activation
        if self.activation is not None:
            rst = self.activation(rst)
        # normalization
        if self.norm is not None:
            rst = self.norm(rst)
        return rst

reset_parameters uses a gain, and the weights are initialized from a Xavier uniform distribution:

$$W \sim U\left[-\text{gain} \cdot \frac{\sqrt{6}}{\sqrt{n_j+n_{j+1}}},\ \text{gain} \cdot \frac{\sqrt{6}}{\sqrt{n_j+n_{j+1}}}\right]$$

Here n_j and n_{j+1} are the input and output sizes (fan-in and fan-out) of the layer.
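As a worked example of the Xavier uniform bound, the sketch below computes it in plain Python. The function names are my own, the sizes 602 and 16 are simply the Reddit feature size and the hidden size used later in this post, and sqrt(2) is the gain that nn.init.calculate_gain('relu') returns.

```python
import math
import random


def xavier_uniform_bound(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    """Bound a such that W ~ U[-a, a], with a = gain * sqrt(6 / (fan_in + fan_out))."""
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


def xavier_uniform_matrix(fan_in: int, fan_out: int, gain: float = 1.0, seed: int = 0):
    """Sample a (fan_out x fan_in) weight matrix uniformly inside the bound."""
    rng = random.Random(seed)
    bound = xavier_uniform_bound(fan_in, fan_out, gain)
    return [[rng.uniform(-bound, bound) for _ in range(fan_in)] for _ in range(fan_out)]


relu_gain = math.sqrt(2.0)  # recommended gain for ReLU
bound = xavier_uniform_bound(602, 16, relu_gain)
weights = xavier_uniform_matrix(602, 16, relu_gain)
```

A larger gain widens the interval, which compensates for ReLU zeroing out roughly half of the activations.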
A careful read of the paper shows that, in the experiments section, the authors actually evaluate four aggregation variants:

[Figure 2: the four aggregation variants evaluated in the paper's experiments]

Let's read the code side by side with the paper.

  1. MEAN aggregator: first take the mean of the neighbor features, then feed the current node's features and the aggregated neighbor features through separate fully connected layers and add the results. The corresponding pseudocode is:

$$\mathbf{h}_{N(v)}^k \leftarrow \text{MEAN}_k(\{\mathbf{h}_u^{k-1}, \forall u \in N(v)\})$$
$$\mathbf{h}_v^k \leftarrow \sigma(\mathbf{W}^k \cdot \text{CONCAT}(\mathbf{h}_v^{k-1}, \mathbf{h}_{N(v)}^k))$$

The corresponding code is:

h_self = feat_dst
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
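# The formula uses CONCAT while the code adds two linear outputs.
# The two are equivalent: W · CONCAT(h_self, h_neigh) = W_self · h_self + W_neigh · h_neigh
# when W is split column-wise into [W_self, W_neigh].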
rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
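The CONCAT in the formula and the sum of two linear layers in the code can be reconciled with a small numeric check. The sketch below uses plain Python lists and made-up numbers; matvec and the two form functions are illustrative names, not part of DGL.

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def concat_form(W, h_self, h_neigh):
    """Paper form: W applied to CONCAT(h_self, h_neigh)."""
    return matvec(W, h_self + h_neigh)  # list + list concatenates


def split_sum_form(W, h_self, h_neigh):
    """Code form: split W column-wise into [W_self, W_neigh] and add the outputs."""
    d = len(h_self)
    W_self = [row[:d] for row in W]
    W_neigh = [row[d:] for row in W]
    return [a + b for a, b in zip(matvec(W_self, h_self), matvec(W_neigh, h_neigh))]


W = [[0.5, -1.0, 2.0, 0.25],
     [1.5, 0.0, -0.5, 1.0]]
h_self = [1.0, 2.0]
h_neigh = [3.0, 4.0]

out_concat = concat_form(W, h_self, h_neigh)
out_split = split_sum_form(W, h_self, h_neigh)
```

Both forms give the same vector, so fc_self and fc_neigh together play the role of the single W^k in the paper. The two bias terms simply add up to one effective bias.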
  2. GCN aggregator: first take the mean over the features of the neighbors and of the node itself, then feed the aggregated feature through a fully connected layer. The corresponding formula in the paper is:

$$\mathbf{h}_v^k \leftarrow \sigma(\mathbf{W} \cdot \text{MEAN}(\{\mathbf{h}_v^{k-1}\} \cup \{\mathbf{h}_u^{k-1}, \forall u \in N(v)\}))$$

The corresponding code is:

graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst   
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
# divide in_degrees
degs = graph.in_degrees().to(feat_dst)
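# The formula takes the mean over a union of sets; summing the neighbors, adding the node
# itself and dividing by (in-degree + 1) computes exactly that mean.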
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
rst = self.fc_neigh(h_neigh)
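To see that the "sum, add self, divide by degree + 1" pattern matches the mean over the union in the formula, here is a small sketch in plain Python with made-up feature vectors. The function names are illustrative only.

```python
def gcn_aggregate(h_self, neighbor_feats):
    """Code form: (sum of neighbor features + own feature) / (in-degree + 1)."""
    dim = len(h_self)
    neigh_sum = [sum(h[i] for h in neighbor_feats) for i in range(dim)]
    degree = len(neighbor_feats)
    return [(neigh_sum[i] + h_self[i]) / (degree + 1) for i in range(dim)]


def mean_over_union(h_self, neighbor_feats):
    """Formula form: element-wise mean over {h_v} united with the neighbor features."""
    vectors = [h_self] + list(neighbor_feats)
    return [sum(column) / len(vectors) for column in zip(*vectors)]


h_self = [1.0, 2.0]
neighbors = [[3.0, 4.0], [5.0, 0.0]]

from_code = gcn_aggregate(h_self, neighbors)
from_formula = mean_over_union(h_self, neighbors)
isolated = gcn_aggregate(h_self, [])  # no neighbors: the node keeps its own feature
```

The "+ 1" in the denominator is what keeps the division well defined for a node without in-edges.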

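The key difference between gcn and mean lies in how the current node is combined with its neighbors. gcn averages the current node together with its neighbors and passes the result through a single fully connected layer (fc_neigh). mean averages only the neighbors and then concatenates the result with the current node's features, which in the code means the two go through separate fully connected layers (fc_self and fc_neigh).

[Figure 3]

Here I borrow the docstrings from the aggregators written by the paper's authors at Stanford as an explanation; if anything is unclear, see the source code in their GitHub repository: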

class MeanAggregator(Layer):
    """
    Aggregates via mean followed by matmul and non-linearity.
    """

class GCNAggregator(Layer):
    """
    Aggregates via mean followed by matmul and non-linearity.
    Same matmul parameters are used self vector and neighbor vectors.
    """
  3. POOL aggregator: each neighbor's vector is first fed through a fully connected layer with a nonlinearity, and an element-wise max-pooling is then applied across the neighbors. The corresponding formula is:

$$\text{AGGREGATE}_k^{pool} = \max(\{\sigma(\mathbf{W}_{pool} \mathbf{h}_{u_i}^k + \mathbf{b}), \forall u_i \in N(v)\})$$

The corresponding code is:

graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
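The pooling step can be traced by hand on a tiny example. The sketch below is plain Python with made-up weights; ReLU stands in for the nonlinearity, as in the F.relu call in the code above, and pool_aggregate is an illustrative name.

```python
def relu(x: float) -> float:
    return max(0.0, x)


def pool_aggregate(W_pool, b, neighbor_feats):
    """relu(W_pool · h + b) for every neighbor, then an element-wise max."""
    transformed = [
        [relu(sum(w * x for w, x in zip(row, h)) + bias) for row, bias in zip(W_pool, b)]
        for h in neighbor_feats
    ]
    return [max(column) for column in zip(*transformed)]


W_pool = [[1.0, 0.0],
          [0.0, -1.0]]
b = [0.0, 1.0]
neighbors = [[2.0, 3.0], [-1.0, 0.5]]

pooled = pool_aggregate(W_pool, b, neighbors)
pooled_reversed = pool_aggregate(W_pool, b, list(reversed(neighbors)))
```

Because max is symmetric, the result does not depend on the order in which the neighbors arrive.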
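  4. LSTM aggregator: it is more expressive than the mean aggregator, but an LSTM is not symmetric, i.e. its output depends on the order of its inputs. The paper's authors adapt it to unordered sets by feeding it a random permutation of the neighbors. The corresponding code is: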
def _lstm_reducer(self, nodes):
  """LSTM reducer
  """
  m = nodes.mailbox['m'] # (B, L, D)
  batch_size = m.shape[0]
  h = (m.new_zeros((1, batch_size, self._in_src_feats)),
       m.new_zeros((1, batch_size, self._in_src_feats)))
  _, (rst, _) = self.lstm(m, h)
  return {'neigh': rst.squeeze(0)}

graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
h_neigh = graph.dstdata['neigh']
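The order sensitivity mentioned for the LSTM aggregator can be illustrated without an actual LSTM. The sketch below uses a toy scalar recurrence, which is an assumption made purely for illustration and not the DGL reducer, and contrasts it with a mean. The weights 0.7 and 0.5 are arbitrary.

```python
import math
import random


def recurrent_reduce(messages, w_in: float = 0.7, w_rec: float = 0.5) -> float:
    """Toy recurrence (NOT an LSTM): the final state depends on the input order."""
    state = 0.0
    for m in messages:
        state = math.tanh(w_in * m + w_rec * state)
    return state


def mean_reduce(messages) -> float:
    """Symmetric reducer: the order of the messages does not matter."""
    return sum(messages) / len(messages)


def shuffled(messages, seed: int):
    """Random permutation of the neighbor messages, as the paper does for the LSTM."""
    rng = random.Random(seed)
    out = list(messages)
    rng.shuffle(out)
    return out


messages = [1.0, -2.0, 3.0]
forward_state = recurrent_reduce(messages)
backward_state = recurrent_reduce(list(reversed(messages)))
permuted = shuffled(messages, seed=0)
```

The recurrent reducer gives different results for the two orders while the mean does not, which is why a sequence model needs the random permutation to be usable on an unordered neighbor set.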

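That covers how to implement GraphSAGE with SAGEConv. The rest of the training code is the same as in the earlier GCN post, so it is not repeated here.

3. Neighbor sampler

This section introduces another implementation that samples node neighbors and runs forward propagation on minibatches.

This approach is suited to large graphs and can be parallelized.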

import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from dgl.nn import SAGEConv
import time
from dgl.data import RedditDataset
import tqdm

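First comes neighbor sampling (NeighborSampler), which is best read alongside the PinSAGE implementation:

[Figure 4: neighbor sampling in PinSAGE]

Focus on the upper half: we first sample the first-order neighbors of node A and then sample the second-order neighbors. The second-order neighbors of A may include A itself and its first-order neighbors.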

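NeighborSampler serves a similar purpose: it starts from the seed nodes on the far right and then performs first-order and second-order sampling in turn. Sampling therefore proceeds from right to left (from the seeds outward), while feature aggregation proceeds from left to right (from the sampled inputs back to the seeds).

[Figure 5: the sampled blocks, with the seed nodes on the far right]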

class NeighborSampler(object):
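    def __init__(self, g, fanouts):
        """
        g is a DGLGraph;
        fanouts is the number of neighbors to sample per layer. The experiment uses 10,25,
        i.e. 10 first-order neighbors and 25 second-order neighbors are sampled.
        """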
        self.g = g
        self.fanouts = fanouts

    def sample_blocks(self, seeds):
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts: 
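            # sample_neighbors samples neighbors for every seed node and returns the corresponding subgraph
            # replace=True means sampling with replacement, so the same neighbor can be drawn more than once
            frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
            # Convert the frontier into a bipartite graph (source and destination nodes) for message passing
            # The source node ids may also contain the destination node ids (for the reason given above)
            block = dgl.to_block(frontier, seeds)
            # Use the source nodes of the new block as the seed nodes of the next layer
            # Seeds come from src because sampling runs in the opposite direction to aggregation
            seeds = block.srcdata[dgl.NID]
            # Put this layer at the front.
            # P.S. insert(0, ...) is O(n), but the list holds only one block per layer, so the cost is negligible.
            blocks.insert(0, block)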
        return blocks
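The control flow of sample_blocks can be mirrored on a plain adjacency dict, which makes the "sample outward, store inward" ordering easier to follow. This is a hedged sketch: sample_frontiers and the dict-based layer records are hypothetical, and placing the destination nodes first among the source nodes is meant to mimic how a block orders its nodes rather than to reproduce dgl.to_block.

```python
import random


def sample_frontiers(in_neighbors, seeds, fanouts, seed=0):
    """Sample one frontier per fanout, starting from the seeds.

    in_neighbors maps a node to the list of nodes with an edge into it.
    Returns the layers ordered from input side to output side.
    """
    rng = random.Random(seed)
    layers = []
    dst_nodes = list(seeds)
    for fanout in fanouts:
        edges = []
        for dst in dst_nodes:
            candidates = in_neighbors.get(dst, [])
            if candidates:
                # Sampling with replacement: a neighbor may appear more than once.
                edges.extend((rng.choice(candidates), dst) for _ in range(fanout))
        # Source nodes: destination nodes first, then any newly reached nodes.
        src_nodes = list(dst_nodes)
        seen = set(dst_nodes)
        for src, _ in edges:
            if src not in seen:
                seen.add(src)
                src_nodes.append(src)
        # Later (outer) layers go to the front, so layers[0] is the input layer.
        layers.insert(0, {"src": src_nodes, "dst": dst_nodes, "edges": edges})
        # The sources of this layer become the destinations of the next one.
        dst_nodes = src_nodes
    return layers


toy_in_neighbors = {0: [1, 2], 1: [3], 2: [3, 4], 3: [], 4: [0]}
layers = sample_frontiers(toy_in_neighbors, seeds=[0], fanouts=[2, 2])
```

With the layers stored this way, the last layer's destination nodes are the original seeds, and each layer's source nodes line up with the next layer's inputs during aggregation.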

The pseudocode of Algorithm 2 is shown below; NeighborSampler corresponds to steps 1-7 of Algorithm 2:

[Figure 6: pseudocode of Algorithm 2]

# GraphSAGE implementation
class GraphSAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        # blocks are the sampled bipartite graphs, used here for message passing
        # x holds the node features
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h_dst = h[:block.number_of_dst_nodes()]
            h = layer(block, (h, h_dst))
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
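        # inference is used for evaluation and testing; it runs on the full graph
        # It currently does some redundant computation; the optimization is still on the to-do list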
        nodes = th.arange(g.number_of_nodes())
        for l, layer in enumerate(self.layers):
            y = th.zeros(g.number_of_nodes(), 
                         self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
            for start in tqdm.trange(0, len(nodes), batch_size):
                end = start + batch_size
                batch_nodes = nodes[start:end]
                block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
                input_nodes = block.srcdata[dgl.NID]
                h = x[input_nodes].to(device)
                h_dst = h[:block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)
                y[start:end] = h.cpu()
            x = y
        return y
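The layer-wise structure of inference (finish one layer for all nodes, batch by batch, before moving to the next layer) can be sketched on a toy graph. The version below is plain Python with scalar features and scalar weights standing in for fc_self and fc_neigh; it omits activation and dropout, and treating the neighbor mean of a node without in-edges as 0.0 is an assumption of this sketch.

```python
def mean_layer(in_neighbors, feats, node, w_self, w_neigh):
    """One 'mean' SAGE update for a single node, using all of its in-neighbors."""
    neigh = in_neighbors.get(node, [])
    h_neigh = sum(feats[u] for u in neigh) / len(neigh) if neigh else 0.0
    return w_self * feats[node] + w_neigh * h_neigh


def layerwise_inference(in_neighbors, feats, layer_weights, batch_size):
    """Compute every node's output one layer at a time, in batches of nodes."""
    nodes = sorted(feats)
    x = dict(feats)
    for w_self, w_neigh in layer_weights:
        y = {}
        for start in range(0, len(nodes), batch_size):
            for node in nodes[start:start + batch_size]:
                # Always read from x (previous layer) and write to y (this layer).
                y[node] = mean_layer(in_neighbors, x, node, w_self, w_neigh)
        x = y  # the finished layer becomes the input of the next one
    return x


toy_in_neighbors = {0: [1, 2], 1: [0], 2: []}
toy_feats = {0: 1.0, 1: 2.0, 2: 4.0}

one_layer = layerwise_inference(toy_in_neighbors, toy_feats, [(1.0, 1.0)], batch_size=2)
two_layers_small = layerwise_inference(toy_in_neighbors, toy_feats, [(1.0, 1.0), (0.5, 0.5)], batch_size=1)
two_layers_large = layerwise_inference(toy_in_neighbors, toy_feats, [(1.0, 1.0), (0.5, 0.5)], batch_size=10)
```

Keeping separate input and output buffers per layer is what makes the result independent of the batch size.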
def compute_acc(pred, labels):
    """
    Compute the accuracy.
    """
    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)

def evaluate(model, g, inputs, labels, val_mask, batch_size, device):
    """
    Evaluate the model by calling the model's inference function.
    """
    model.eval()
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    return compute_acc(pred[val_mask], labels[val_mask])

def load_subtensor(g, labels, seeds, input_nodes, device):
    """
    Copy the features and labels of a set of nodes onto the target device (GPU if available).
    """
    batch_inputs = g.ndata['features'][input_nodes].to(device)
    batch_labels = labels[seeds].to(device)
    return batch_inputs, batch_labels
# Hyperparameters
gpu = -1
num_epochs = 20
num_hidden = 16
num_layers = 2
fan_out = '10,25'
batch_size = 1000
log_every = 20  # logging frequency (in steps)
eval_every = 5
lr = 0.003
dropout = 0.5
num_workers = 0  # number of worker processes used for sampling

if gpu >= 0:
    device = th.device('cuda:%d' % gpu)
else:
    device = th.device('cpu')

# load reddit data
# NumNodes: 232965
# NumEdges: 114848857
# NumFeats: 602
# NumClasses: 41
# NumTrainingSamples: 153431
# NumValidationSamples: 23831
# NumTestSamples: 55703
data = RedditDataset(self_loop=True)
train_mask = data.train_mask
val_mask = data.val_mask
features = th.Tensor(data.features)
in_feats = features.shape[1]
labels = th.LongTensor(data.labels)
n_classes = data.num_labels
# Construct graph
g = dgl.graph(data.graph.all_edges())
g.ndata['features'] = features

Start training:

train_nid = th.LongTensor(np.nonzero(train_mask)[0])
val_nid = th.LongTensor(np.nonzero(val_mask)[0])
train_mask = th.BoolTensor(train_mask)
val_mask = th.BoolTensor(val_mask)

# Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in fan_out.split(',')])

# Create PyTorch DataLoader for constructing blocks
# collate_fn is set to the sampler, which samples blocks for the nodes in each batch
dataloader = DataLoader(
    dataset=train_nid.numpy(),
    batch_size=batch_size,
    collate_fn=sampler.sample_blocks,
    shuffle=True,
    drop_last=False,
    num_workers=num_workers)

# Define model and optimizer
model = GraphSAGE(in_feats, num_hidden, n_classes, num_layers, F.relu, dropout)
model = model.to(device)
loss_fcn = nn.CrossEntropyLoss()
loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
avg = 0
iter_tput = []
for epoch in range(num_epochs):
    tic = time.time()

    for step, blocks in enumerate(dataloader):
        tic_step = time.time()

        input_nodes = blocks[0].srcdata[dgl.NID]
        seeds = blocks[-1].dstdata[dgl.NID]

        # Load the input features as well as output labels
        batch_inputs, batch_labels = load_subtensor(g, labels, seeds, input_nodes, device)

        # Compute loss and prediction
        batch_pred = model(blocks, batch_inputs)
        loss = loss_fcn(batch_pred, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iter_tput.append(len(seeds) / (time.time() - tic_step))
        if step % log_every == 0:
            acc = compute_acc(batch_pred, batch_labels)
            gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
            print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB'.format(
                epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc))

    toc = time.time()
    print('Epoch Time(s): {:.4f}'.format(toc - tic))
    if epoch >= 5:
        avg += toc - tic
    if epoch % eval_every == 0 and epoch != 0:
        eval_acc = evaluate(model, g, g.ndata['features'], labels, val_mask, batch_size, device)
        print('Eval Acc {:.4f}'.format(eval_acc))

print('Avg epoch time: {}'.format(avg / (epoch - 4)))

4. Reference

  1. GitHub: dmlc/dgl
  2. williamleif/GraphSAGE

Follow the WeChat official account 阿泽的学习笔记 for the latest posts.
