pytorch:torch_scatter.scatter_max()

torch_scatter.scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None)

在这里插入图片描述

  • 根据index将src分组,求每一组中的最大值输出到out
  • dim是维度
    在这里插入图片描述
from torch_scatter import scatter_max

src = torch.Tensor([[2, 0

你可能感兴趣的:(pytorch)