torch_scatter

官方文档

文章目录

      • `scatter`

scatter

用一张官网的图
torch_scatter_第1张图片

import torch
from torch_scatter import scatter

index = torch.tensor([0,0,1,0,2])
input = torch.tensor([[1,1],[1,1],[2,2],[1,1],[1,1]])
result = scatter(input,index,dim=0,reduce="sum")
"""
tensor([[3, 3],
        [2, 2],
        [1, 1]])
"""
index = 0 0 1 0 2
input = [1,1] [1,1] [2,2] [1,1] [1,1]

index = 0 的 input有 :[1,1] [1,1] [1,1] ,sum为[3,3]
index = 1 的 input有 :[2,2] ,sum为[2,2]
index = 2 的 input有 :[1,1],sum为[1,1]

inputshape[5,2],由于函数中dim=0,而index有3个不用的值,index所以将5换成3.result的形状应为[5 3,2]
故:

  • result[0] = [3,3]
  • result[1] = [2,2]
  • result[2] = [1,1]
index = torch.tensor([0,0,1])
input = torch.tensor([[1,1,1],[1,1,2],[2,2,3],[21,10,9]])
scatter(input,index,dim=1,reduce="sum")
"""
tensor([[ 2,  1],
        [ 2,  2],
        [ 4,  3],
        [31,  9]])
"""

index = 0 的 input有 : [ 1 , 1 , 2 , 21 ] T [1,1,2,21]^T [1,1,2,21]T [ 1 , 1 , 2 , 10 ] T [1,1,2,10]^T [1,1,2,10]T ,sum为 [ 2 , 2 , 4 , 31 ] T [2,2,4,31]^T [2,2,4,31]T
index = 1 的 input有 : [ 1 , 2 , 3 , 9 ] T [1,2,3,9]^T [1,2,3,9]T,sum为 [ 1 , 2 , 3 , 9 ] T [1,2,3,9]^T [1,2,3,9]T

torch_scatter_第2张图片
reduce参数可选值有:

  • sum 求和
  • mul 乘法
  • mean 平均
  • min 最小
  • max 最大

你可能感兴趣的:(Python)