torch_scatter.scatter_max函数

scatter_max 这个函数还挺有用的,pointnet见到的

from torch_scatter import scatter_max  

x, _ = scatter_max(x, data.batch, dim=0)

例子

from torch_scatter import scatter_max

import torch 
if __name__ == "__main__":
    t1 = torch.tensor([[0, 2], [2, 2], [3, 4], [7, 8], [3, 5]])
    t2 = torch.tensor([0, 0, 0, 1, 1])
    out, _ = scatter_max(t1, t2, dim=0)
    print(out)

输出结果:

tensor([[3,4],
        [7,8]])

1.dim=0代表查询维度
2.根据 t2 把 t1 的每个元素归组,这里前三个是一组,后俩是一组。
3.将第一组和第二组中最大元素挑出来。
4.输出:第一组 [3, 4] ,第二组 [7, 8] 。

你可能感兴趣的:(pytorch,深度学习,机器学习)