DGL中NN模块的构造函数


上图引用自:dgl用户文档第三章(nn模块编写)

"""

构造函数完成以下几个任务:
1、设置选项。
2、注册可学习的参数或者子模块。
3、初始化参数。

"""
import torch.nn as nn
from dgl.utils import expand_as_pair
import dgl.nn
import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape
"""
在构造函数中,用户首先需要设置数据的维度。
对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。 
对于图神经网络,输入维度可被分为源节点特征维度和目标节点特征维度。

除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type)。
对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。 常用的聚合类型
包括 mean、 sum、 max 和 min。一些模块可能会使用更加复杂的聚合函数,
比如 lstm。
"""
"""注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linear、 nn.LSTM 等。"""

class SAGE(nn.Module):
    def __init__(self, in_feats, out_feats, aggregator_type,
                 bias=True, norm=None, activation=None):
        super(SAGE, self).__init__()
        # 获取源节点和目标节点的输入特征维度
        self._in_src_feats, self._in_dest_feats = expand_as_pair(in_feats)
        # 输出特征维度
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
        # 聚合类型:mean、pool、lstm、gcn
        if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()
    #  构造函数的最后调用了 reset_parameters() 进行权重初始化。
    def reset_parameters(self):
        """重新初始化可学习的参数"""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)# 上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化: hv=hv/∥hv∥2
    def forward(self, graph, feat):    #SAGEConv示例中的 forward() 函数
        # 输入图对象的规范检测
        with graph.local_scope():
            # 指定图类型,然后根据图类型扩展输入特征
            feat_src, feat_dst = expand_as_pair(feat, graph)
        # 消息传递和聚合
        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
            # 除以入度
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

        # GraphSAGE中gcn聚合不需要fc_self
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        # 更新特征作为输出
        # 激活函数
        if self.activation is not None:
            rst = self.activation(rst)
        # 归一化
        if self.norm is not None:
            rst = self.norm(rst)
        return rst
"""
在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比,
DGL NN模块额外增加了1个参数 :class:dgl.DGLGraph。forward() 函数的内容一般可以分为3项操作:
1、检测输入图对象是否符合规范。
2、消息传递和聚合。
3、聚合后,更新特征作为输出。

forward() 函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。 
比如在 GraphConv 等conv模块中,DGL会检查输入图中是否有入度为0的节点。 当1个节点入
度为0时, mailbox 将为空,并且聚合函数的输出值全为0, 这可能会导致模型性能不佳。但是
,在 SAGEConv 模块中,被聚合的特征将会与节点的初始特征拼接起来, forward() 函数的输
出不会全为0。在这种情况下,无需进行此类检验。
DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图和子图块。



聚合部分的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息
传递均使用 update_all() API和 DGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的
消息传递代码 里所介绍的性能优化。

聚合后,更新特征作为输出
forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。
"""


你可能感兴趣的:(DGL,深度学习,人工智能,DGL,GNN,图卷积神经网络,异构图,GCN)