torch_scatter
torch_scatter是pytorch_geometric作者基于pytorch做的small extension library of highly optimized sparse update (scatter and segment) operations
scatter_add_
是pytorch中实现的函数,上述函数很多是基于此所作,只不过当前函数侧重于矩阵的计算,而前者侧重于图相关的计算
scatter_add_
是scatter
的一个例子,pytorch对scatter函数的解释如下:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
>>>self.scatter_(dim, index, src, reduce)
src = torch.arange(1, 11).reshape((2, 5))
src
>>>tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
>>>tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
>>>tensor([[4, 2, 3, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
图中相同颜色的填充代表同一个位置,线条颜色则代表数字的分配,dim指对应self的维度;index分别应该是self的[dim,[index]],对应src待操作的数字应该是src[:index.shape]
按照上述图示来看,需要注意的几点就是index的数值和维度分别对应的是self和src的取值
有了上述的理解,对于torch_scatter中的scatter_add更好理解了
src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0,1,2,0,3],[0,1,1,2,2]])
torch_scatter.scatter_add(src, index)
>>>
tensor([[ 5, 2, 3, 5],
[ 6, 15, 19, 0]])
torch_scatter.scatter_add(src, index, dim=0)
>>>
tensor([[ 7, 0, 0, 4, 0],
[ 0, 9, 8, 0, 0],
[ 0, 0, 3, 9, 10],
[ 0, 0, 0, 0, 5]])
需要注意的几点:
对于此函数,主要知道其应用场景:
scatter_add(edge_weight, edge_index[1], dim=0)
其意义就是将每个target node的与其邻接节点的边的权重之求和,最终得到的输出维度是节点数目;如果weight是0或者1,则得到的是degree,如果选择的是target节点则是入度,否则是出度。