Google在2017发表的论文Neural Message Passing for Quantum Chemistry中提到的Message Passing Neural Networks机制成为了后来图机器学习计算的标准范式实现。
在PyG图机器学习库中,MessagePassing类实现了上述机制,并作为所有图卷积层的基类,该机制最重要的公式如下:
该机制过程主要有下面三个函数保证:
消息传递,message函数
消息聚合,aggregate函数
节点更新,update函数
MessagePassing类中的propagate方法,会依次调用message,aggregate,和update方法,完成消息的传递,聚合,和更新,如果子类图卷积层重写了message,aggregate和update函数,则实例调用的是子类方法。
下面将对MessagePassing类进行源码分析,但是主要从两个问题进行展开,这样可以更加地聚焦源码的实现过程。
Q1:GCN,SAGEConv等图卷积层需要继承MessagePassing类,然后重写消息传递(message),聚合(aggregate),和更新(update)函数,一般情况下,子类中自定义的这些函数会引入很多未知的自定义入参,那么这些自定义的入参是如何在统一的MessagePassing框架下被调用的呢?
答案就是MessagePassing中的Inspector类。
MessagePassing基类的构造函数中初始化了一个非常重要的类Inspector,这个类实现了对子类中自定义的message,aggregate,message_and_aggregate,以及update函数的入参提取。
self.inspector = Inspector(self)
self.inspector.inspect(self.message)
self.inspector.inspect(self.aggregate, pop_first=True)
self.inspector.inspect(self.message_and_aggregate, pop_first=True)
self.inspector.inspect(self.update, pop_first=True)
self.__user_args__ = self.inspector.keys(
['message', 'aggregate', 'update']).difference(self.special_args)
下面看一下这个Inspector类的实现:
import re
import inspect
from collections import OrderedDict
from typing import Dict, List, Any, Optional, Callable, Set
from .typing import parse_types
class Inspector(object):
def __init__(self, base_class: Any):
self.base_class: Any = base_class
self.params: Dict[str, Dict[str, Any]] = {}
def inspect(self, func: Callable,
pop_first: bool = False) -> Dict[str, Any]:
## 注册func函数的入参,并建立func与入参之间的对应关系
params = inspect.signature(func).parameters
params = OrderedDict(params)
if pop_first:
params.popitem(last=False)
self.params[func.__name__] = params
def keys(self, func_names: Optional[List[str]] = None) -> Set[str]:
keys = []
for func in func_names or list(self.params.keys()):
keys += self.params[func].keys()
return set(keys)
def __implements__(self, cls, func_name: str) -> bool:
if cls.__name__ == 'MessagePassing':
return False
if func_name in cls.__dict__.keys():
return True
return any(self.__implements__(c, func_name) for c in cls.__bases__)
def implements(self, func_name: str) -> bool:
return self.__implements__(self.base_class.__class__, func_name)
def types(self, func_names: Optional[List[str]] = None) -> Dict[str, str]:
out: Dict[str, str] = {}
for func_name in func_names or list(self.params.keys()):
func = getattr(self.base_class, func_name)
arg_types = parse_types(func)[0][0]
for key in self.params[func_name].keys():
if key in out and out[key] != arg_types[key]:
raise ValueError(
(f'Found inconsistent types for argument {key}. '
f'Expected type {out[key]} but found type '
f'{arg_types[key]}.'))
out[key] = arg_types[key]
return out
def distribute(self, func_name, kwargs: Dict[str, Any]):
## 通过给定的函数func_name,从kwargs获取其入参值
out = {}
for key, param in self.params[func_name].items():
data = kwargs.get(key, inspect.Parameter.empty)
if data is inspect.Parameter.empty:
if param.default is inspect.Parameter.empty:
raise TypeError(f'Required parameter {key} is empty.')
data = param.default
out[key] = data
return out
这个类的实现借助了python的inspect模块(可以参考一下python官网关于该模块的解释,inspect --- 检查对象 — Python 3.7.12 文档),可以检查子类自定义函数的参数,重点看一下这个类的inspect函数,该函数借助inspect.signature(),获取了子类的函数入参,比如当func="message"时,params = inspect.signature('message').parameters就会获得子类自定义message函数的参数,注意这里一定是子类的message函数,最后self.params建立函数名与其入参的字典,方便后面distribute函数的调用。
def inspect(self, func: Callable,
pop_first: bool = False) -> Dict[str, Any]:
params = inspect.signature(func).parameters # 获取子类func函数的函数入参
params = OrderedDict(params)
if pop_first:
params.popitem(last=False) ## 去除掉self这个入参
self.params[func.__name__] = params ## 建立函数与其入参的字典。
Q2:上述的Inspector类实现了message,aggregate,message_and_aggregate,以及update函数与其各自入参的对应关系映射,但是一般的图卷积层是通过的forward函数进行调用的,通常的调用顺序如下,那么是如何将自定义的参数kwargs与后续的函数的入参进行对应的呢?
注:这里的子类是自定义图卷积层,父类指的是MessagePassing类,如果子类实现的对应的函数,那么调用的就是子类的方法,否则调用父类的方法。
可以看出,Inspector类解决了橙色部分的函数参数绑定问题,但是参数是从forward传递进来的,如何在propagate中将参数对应到传递后面的函数中呢?这部分的参数对应关系主要由MessagePassing类的__collect__函数进行参数收集和数据赋值。
这里有个重要的原则:就是对于后续函数需要接受的_i和_j类的参数需要从kwargs中按照一定的规则进行提取,对于其他的参数都是直接透传。
def __collect__(self, args, edge_index, size, kwargs):
i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
out = {}
for arg in args:# 遍历自定义函数中的参数
if arg[-2:] not in ['_i', '_j']: # 不包含_i和_j的自定义参数直接透传
out[arg] = kwargs.get(arg, Parameter.empty) # 从用户传递进来的kwargs参数中获取值
else:
dim = 0 if arg[-2:] == '_j' else 1 # 注意这里的取值维度
data = kwargs.get(arg[:-2], Parameter.empty) # 取用户传递进来的kwargs前缀arg[:-2]的数据
if isinstance(data, (tuple, list)):
assert len(data) == 2
if isinstance(data[1 - dim], Tensor):
self.__set_size__(size, 1 - dim, data[1 - dim])
data = data[dim]
if isinstance(data, Tensor):
self.__set_size__(size, dim, data)
data = self.__lift__(data, edge_index,
j if arg[-2:] == '_j' else i)
out[arg] = data
if isinstance(edge_index, Tensor):
out['adj_t'] = None
out['edge_index'] = edge_index
out['edge_index_i'] = edge_index[i]
out['edge_index_j'] = edge_index[j]
out['ptr'] = None
elif isinstance(edge_index, SparseTensor):
out['adj_t'] = edge_index
out['edge_index'] = None
out['edge_index_i'] = edge_index.storage.row()
out['edge_index_j'] = edge_index.storage.col()
out['ptr'] = edge_index.storage.rowptr()
out['edge_weight'] = edge_index.storage.value()
out['edge_attr'] = edge_index.storage.value()
out['edge_type'] = edge_index.storage.value()
out['index'] = out['edge_index_i']
out['size'] = size
out['size_i'] = size[1] or size[0]
out['size_j'] = size[0] or size[1]
out['dim_size'] = out['size_i']
return out
coll_dict = self.__collect__(self.__user_args__, edge_index, size,
kwargs)
首先需要明确,会有一些固定的图神经网络参数,主要包括如下:
special_args: Set[str] = { 'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j', 'ptr', 'index', 'dim_size' }
__collect__函数中的args主要对应子类中相关函数(message,aggregate,update等)的自定义参数self.__user_args__,kwargs为子类的forward函数中调用propagate传递进来的参数。可以仔细阅读__collect__函数的代码可以发现,self.__user_args__中_i和_j后缀的参数非常重要,这里有一个约定就是,i表示与target节点相关的参数,j表示source节点相关的参数,其图上的指向为j->i for j 属于N(i),后缀不包含_i和_j的参数直接被透传。 举个gat_conv的例子,gat_conv的forward函数中调用propagate函数:
out = self.propagate(edge_index, x=(x_l, x_r), alpha=(alpha_l, alpha_r), size=size)
其,传递的参数为x,alpha为自定义的,而gat_conv的message函数定义为:
def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
index: Tensor, ptr: OptTensor,
size_i: Optional[int]) -> Tensor:
alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
alpha = F.leaky_relu(alpha, self.negative_slope)
alpha = softmax(alpha, index, ptr, size_i)
self._alpha = alpha
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
return x_j * alpha.unsqueeze(-1)
经过__collect__函数的处理后,其对应关系如下:
x_j = x_l
alpha_j = alpha_l
alpha_i = alpha_r
最后,propagate中依次从coll_dict
中获取与message,aggregate,update函数的参数进行调用。注意这里获取的参数是通过上述的self.inspector.distribute
函数进行获取的。
def propagate(self,..):
##...
##...
msg_kwargs = self.inspector.distribute('message', coll_dict)
out = self.message(**msg_kwargs)
##...
##...
aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)
update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)
【图神经网络】PyG的MessagePassing机制实现 - 知乎
Creating Message Passing Networks — pytorch_geometric 2.0.1 documentation