【图神经网络】PyG的MessagePassing机制实现

概述

Google在2017发表的论文Neural Message Passing for Quantum Chemistry中提到的Message Passing Neural Networks机制成为了后来图机器学习计算的标准范式实现。

在PyG图机器学习库中,MessagePassing类实现了上述机制,并作为所有图卷积层的基类,该机制最重要的公式如下:

该机制过程主要有下面三个函数保证:

  • 消息传递,message函数

  • 消息聚合,aggregate函数

  • 节点更新,update函数

MessagePassing类中的propagate方法,会依次调用message,aggregate,和update方法,完成消息的传递,聚合,和更新,如果子类图卷积层重写了message,aggregate和update函数,则实例调用的是子类方法。

源码分析

下面将对MessagePassing类进行源码分析,但是主要从两个问题进行展开,这样可以更加地聚焦源码的实现过程。

Q1:GCN,SAGEConv等图卷积层需要继承MessagePassing类,然后重写消息传递(message),聚合(aggregate),和更新(update)函数,一般情况下,子类中自定义的这些函数会引入很多未知的自定义入参,那么这些自定义的入参是如何在统一的MessagePassing框架下被调用的呢?

答案就是MessagePassing中的Inspector类。

MessagePassing基类的构造函数中初始化了一个非常重要的类Inspector,这个类实现了对子类中自定义的message,aggregate,message_and_aggregate,以及update函数的入参提取。

self.inspector = Inspector(self)
self.inspector.inspect(self.message)
self.inspector.inspect(self.aggregate, pop_first=True)
self.inspector.inspect(self.message_and_aggregate, pop_first=True)
self.inspector.inspect(self.update, pop_first=True)
self.__user_args__ = self.inspector.keys(
    ['message', 'aggregate', 'update']).difference(self.special_args)

下面看一下这个Inspector类的实现:

import re
import inspect
from collections import OrderedDict
from typing import Dict, List, Any, Optional, Callable, Set

from .typing import parse_types


class Inspector(object):
    def __init__(self, base_class: Any):
        self.base_class: Any = base_class
        self.params: Dict[str, Dict[str, Any]] = {}

    def inspect(self, func: Callable,
                pop_first: bool = False) -> Dict[str, Any]:
        ## 注册func函数的入参,并建立func与入参之间的对应关系
        params = inspect.signature(func).parameters
        params = OrderedDict(params)
        if pop_first:
            params.popitem(last=False)
        self.params[func.__name__] = params

    def keys(self, func_names: Optional[List[str]] = None) -> Set[str]:
        keys = []
        for func in func_names or list(self.params.keys()):
            keys += self.params[func].keys()
        return set(keys)

    def __implements__(self, cls, func_name: str) -> bool:
        if cls.__name__ == 'MessagePassing':
            return False
        if func_name in cls.__dict__.keys():
            return True
        return any(self.__implements__(c, func_name) for c in cls.__bases__)

    def implements(self, func_name: str) -> bool:
        return self.__implements__(self.base_class.__class__, func_name)

    def types(self, func_names: Optional[List[str]] = None) -> Dict[str, str]:
        out: Dict[str, str] = {}
        for func_name in func_names or list(self.params.keys()):
            func = getattr(self.base_class, func_name)
            arg_types = parse_types(func)[0][0]
            for key in self.params[func_name].keys():
                if key in out and out[key] != arg_types[key]:
                    raise ValueError(
                        (f'Found inconsistent types for argument {key}. '
                         f'Expected type {out[key]} but found type '
                         f'{arg_types[key]}.'))
                out[key] = arg_types[key]
        return out

    def distribute(self, func_name, kwargs: Dict[str, Any]):
        ## 通过给定的函数func_name,从kwargs获取其入参值
        out = {}
        for key, param in self.params[func_name].items():
            data = kwargs.get(key, inspect.Parameter.empty)
            if data is inspect.Parameter.empty:
                if param.default is inspect.Parameter.empty:
                    raise TypeError(f'Required parameter {key} is empty.')
                data = param.default
            out[key] = data
        return out

这个类的实现借助了python的inspect模块(可以参考一下python官网关于该模块的解释,inspect --- 检查对象 — Python 3.7.12 文档),可以检查子类自定义函数的参数,重点看一下这个类的inspect函数,该函数借助inspect.signature(),获取了子类的函数入参,比如当func="message"时,params = inspect.signature('message').parameters就会获得子类自定义message函数的参数,注意这里一定是子类的message函数,最后self.params建立函数名与其入参的字典,方便后面distribute函数的调用。

def inspect(self, func: Callable,
            pop_first: bool = False) -> Dict[str, Any]:
    params = inspect.signature(func).parameters # 获取子类func函数的函数入参
    params = OrderedDict(params)
    if pop_first:
        params.popitem(last=False) ## 去除掉self这个入参
    self.params[func.__name__] = params ## 建立函数与其入参的字典。

Q2:上述的Inspector类实现了message,aggregate,message_and_aggregate,以及update函数与其各自入参的对应关系映射,但是一般的图卷积层是通过的forward函数进行调用的,通常的调用顺序如下,那么是如何将自定义的参数kwargs与后续的函数的入参进行对应的呢?

【图神经网络】PyG的MessagePassing机制实现_第1张图片

注:这里的子类是自定义图卷积层,父类指的是MessagePassing类,如果子类实现的对应的函数,那么调用的就是子类的方法,否则调用父类的方法。

可以看出,Inspector类解决了橙色部分的函数参数绑定问题,但是参数是从forward传递进来的,如何在propagate中将参数对应到传递后面的函数中呢?这部分的参数对应关系主要由MessagePassing类的__collect__函数进行参数收集和数据赋值。

这里有个重要的原则:就是对于后续函数需要接受的_i和_j类的参数需要从kwargs中按照一定的规则进行提取,对于其他的参数都是直接透传。

def __collect__(self, args, edge_index, size, kwargs):
    i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)

    out = {}
    for arg in args:# 遍历自定义函数中的参数
        if arg[-2:] not in ['_i', '_j']: # 不包含_i和_j的自定义参数直接透传
            out[arg] = kwargs.get(arg, Parameter.empty) # 从用户传递进来的kwargs参数中获取值
        else:
            dim = 0 if arg[-2:] == '_j' else 1 # 注意这里的取值维度
            data = kwargs.get(arg[:-2], Parameter.empty) # 取用户传递进来的kwargs前缀arg[:-2]的数据

            if isinstance(data, (tuple, list)):
                assert len(data) == 2
                if isinstance(data[1 - dim], Tensor):
                    self.__set_size__(size, 1 - dim, data[1 - dim])
                data = data[dim]

            if isinstance(data, Tensor):
                self.__set_size__(size, dim, data)
                data = self.__lift__(data, edge_index,
                                     j if arg[-2:] == '_j' else i)

            out[arg] = data

    if isinstance(edge_index, Tensor):
        out['adj_t'] = None
        out['edge_index'] = edge_index
        out['edge_index_i'] = edge_index[i]
        out['edge_index_j'] = edge_index[j]
        out['ptr'] = None
    elif isinstance(edge_index, SparseTensor):
        out['adj_t'] = edge_index
        out['edge_index'] = None
        out['edge_index_i'] = edge_index.storage.row()
        out['edge_index_j'] = edge_index.storage.col()
        out['ptr'] = edge_index.storage.rowptr()
        out['edge_weight'] = edge_index.storage.value()
        out['edge_attr'] = edge_index.storage.value()
        out['edge_type'] = edge_index.storage.value()

    out['index'] = out['edge_index_i']
    out['size'] = size
    out['size_i'] = size[1] or size[0]
    out['size_j'] = size[0] or size[1]
    out['dim_size'] = out['size_i']

    return out

coll_dict = self.__collect__(self.__user_args__, edge_index, size,
                             kwargs)

首先需要明确,会有一些固定的图神经网络参数,主要包括如下:

special_args: Set[str] = { 'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j', 'ptr', 'index', 'dim_size' }

__collect__函数中的args主要对应子类中相关函数(message,aggregate,update等)的自定义参数self.__user_args__,kwargs为子类的forward函数中调用propagate传递进来的参数。可以仔细阅读__collect__函数的代码可以发现,self.__user_args__中_i和_j后缀的参数非常重要,这里有一个约定就是,i表示与target节点相关的参数,j表示source节点相关的参数,其图上的指向为j->i for j 属于N(i),后缀不包含_i和_j的参数直接被透传。 举个gat_conv的例子,gat_conv的forward函数中调用propagate函数:

out = self.propagate(edge_index, x=(x_l, x_r), alpha=(alpha_l, alpha_r), size=size)

其,传递的参数为x,alpha为自定义的,而gat_conv的message函数定义为:

def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
            index: Tensor, ptr: OptTensor,
            size_i: Optional[int]) -> Tensor:
    alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
    alpha = F.leaky_relu(alpha, self.negative_slope)
    alpha = softmax(alpha, index, ptr, size_i)
    self._alpha = alpha
    alpha = F.dropout(alpha, p=self.dropout, training=self.training)
    return x_j * alpha.unsqueeze(-1)

经过__collect__函数的处理后,其对应关系如下:

x_j = x_l
alpha_j = alpha_l
alpha_i = alpha_r

最后,propagate中依次从coll_dict中获取与message,aggregate,update函数的参数进行调用。注意这里获取的参数是通过上述的self.inspector.distribute函数进行获取的。

def propagate(self,..):
    ##...
    ##...
    msg_kwargs = self.inspector.distribute('message', coll_dict)
    out = self.message(**msg_kwargs)
    ##...
    ##...
    aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
    out = self.aggregate(out, **aggr_kwargs)
    
    update_kwargs = self.inspector.distribute('update', coll_dict)
	return self.update(out, **update_kwargs)

参考文章

​​​​​​【图神经网络】PyG的MessagePassing机制实现 - 知乎

Creating Message Passing Networks — pytorch_geometric 2.0.1 documentation

你可能感兴趣的:(图机器学习,图神经网络,图机器学习)