【Torch_scatter API】关于scatter_add函数的用法

关于scatter_add函数的用法

两种scatter函数的关系

torch_scatter
torch_scatter是pytorch_geometric作者基于pytorch做的small extension library of highly optimized sparse update (scatter and segment) operations
scatter_add_
是pytorch中实现的函数,上述函数很多是基于此所作,只不过当前函数侧重于矩阵的计算,而前者侧重于图相关的计算

scatter_add_

文字解释

scatter_add_scatter的一个例子,pytorch对scatter函数的解释如下:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
>>>self.scatter_(dim, index, src, reduce)
  • dim即维度,是对于self而言的,即在self的哪一dim进行操作
  • index是索引,即要在self的哪一index进行操作  index的维度可以小于等于src,如果二者维度相同,则相当于将src的每一个数字都加到self的对应index上;如果index维度小,例如src: shape[5,3], index: shape[3,2]则代表只有src[:3,:2]的数字参与了操作
  • src是待操作的源数字,比较好理解
  • reduce代表操作的方式,none代表直接赋值,add则是+=,multiply是*= 因此scatter的意思就是 将src中前index部分的数字以一定的方式scatter(散布)到self中

以代码和对应图像为例对上述进行解释

src = torch.arange(1, 11).reshape((2, 5))
src
>>>tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
>>>tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

【Torch_scatter API】关于scatter_add函数的用法_第1张图片

index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
>>>tensor([[4, 2, 3, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]])

 【Torch_scatter API】关于scatter_add函数的用法_第2张图片

 图中相同颜色的填充代表同一个位置,线条颜色则代表数字的分配,dim指对应self的维度;index分别应该是self的[dim,[index]],对应src待操作的数字应该是src[:index.shape]

总结

按照上述图示来看,需要注意的几点就是index的数值和维度分别对应的是self和src的取值

scatter_add

有了上述的理解,对于torch_scatter中的scatter_add更好理解了

src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0,1,2,0,3],[0,1,1,2,2]])
torch_scatter.scatter_add(src, index)
>>>
tensor([[ 5,  2,  3,  5],
        [ 6, 15, 19,  0]])

torch_scatter.scatter_add(src, index, dim=0)
>>>
tensor([[ 7,  0,  0,  4,  0],
        [ 0,  9,  8,  0,  0],
        [ 0,  0,  3,  9, 10],
        [ 0,  0,  0,  0,  5]])

需要注意的几点:

  • dim默认为-1
  • index的值代表的是输出的维度,比如最大为100则输出的dim对应的维度为101
  • 源码中开始会做一个broadcast将维度扩展

对于此函数,主要知道其应用场景:

scatter_add(edge_weight, edge_index[1], dim=0)

其意义就是将每个target node的与其邻接节点的边的权重之求和,最终得到的输出维度是节点数目;如果weight是0或者1,则得到的是degree,如果选择的是target节点则是入度,否则是出度。

你可能感兴趣的:(语法,基础知识,pytorch,深度学习,python)