MessagePassing是torch_geometric中GNN模型的基类,实现了下面的消息传递公式
要继承这个类,需要复写三个函数:
propagate(edge_index, size=None)
message()
消息传递分两种方式,默认的是source_to_target
update()
其中propagate在执行的过程中会调用message和update
。。。
#source=>target的消息传播
out = self.message(*message_args)
#out为source顶点,out的shape为[E,channel],其中E为边的条数,channel为顶点embedding的维度
out = scatter_(self.aggr, out, edge_index[i], dim, dim_size=size[i])
#将关联边的信息加(默认‘add’)到target的顶点上,out的shape为[V,channel],其中V为target顶点的个数
out = self.update(out, *update_args)
return out
假设顶点V1和顶点v2,v3,v4,.....vn有边相连,propagate做的事情是将v2,v3,v4,.....vn的信息加(默认‘add’,也可以‘mean’,‘max’)到v1上。
GCN的实现,三个函数都是在MessagePassing的基础上实现的。
唯一关键的一步是norm函数,根据GCN的信息传播的公式,计算邻接矩阵和对角度矩阵。
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels, improved=False, cached=False,
bias=True, **kwargs):
super(GCNConv, self).__init__(aggr='add', **kwargs)
#略
@staticmethod
def norm(edge_index, num_nodes, edge_weight=None, improved=False,
dtype=None):
#略
#最关键的只有这一步,计算邻接矩阵和对角度矩阵,根据GCN的信息传播的公式
row, col = edge_index
deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
def forward(self, x, edge_index, edge_weight=None):
""""""
x = torch.matmul(x, self.weight)
#略去代码 主要是设置是否缓存
edge_index, norm = self.cached_result
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
if self.bias is not None:
aggr_out = aggr_out + self.bias
return aggr_out