官方文档
scatter
import torch
from torch_scatter import scatter
index = torch.tensor([0,0,1,0,2])
input = torch.tensor([[1,1],[1,1],[2,2],[1,1],[1,1]])
result = scatter(input,index,dim=0,reduce="sum")
"""
tensor([[3, 3],
[2, 2],
[1, 1]])
"""
index = | 0 | 0 | 1 | 0 | 2 |
---|---|---|---|---|---|
input = | [1,1] | [1,1] | [2,2] | [1,1] | [1,1] |
index = 0
的 input有 :[1,1]
[1,1]
[1,1]
,sum为[3,3]
index = 1
的 input有 :[2,2]
,sum为[2,2]
index = 2
的 input有 :[1,1]
,sum为[1,1]
input
的shape
为[5,2]
,由于函数中dim=0
,而index
有3个不用的值,index
所以将5换成3.result的形状应为[5 3,2]
故:
result[0] = [3,3]
result[1] = [2,2]
result[2] = [1,1]
index = torch.tensor([0,0,1])
input = torch.tensor([[1,1,1],[1,1,2],[2,2,3],[21,10,9]])
scatter(input,index,dim=1,reduce="sum")
"""
tensor([[ 2, 1],
[ 2, 2],
[ 4, 3],
[31, 9]])
"""
index = 0
的 input有 : [ 1 , 1 , 2 , 21 ] T [1,1,2,21]^T [1,1,2,21]T [ 1 , 1 , 2 , 10 ] T [1,1,2,10]^T [1,1,2,10]T ,sum为 [ 2 , 2 , 4 , 31 ] T [2,2,4,31]^T [2,2,4,31]T
index = 1
的 input有 : [ 1 , 2 , 3 , 9 ] T [1,2,3,9]^T [1,2,3,9]T,sum为 [ 1 , 2 , 3 , 9 ] T [1,2,3,9]^T [1,2,3,9]T
sum
求和mul
乘法mean
平均min
最小max
最大