torch_scatter.scatter详解

scatter方法通过src和index两个张量来获得一个新的张量。

torch_scatter.scatter(src: torch.Tensor, index: torch.Tensor, dim: int = - 1, out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, reduce: str = 'sum') → torch.Tensor

原理如图,根据index,将index相同值对应的src元素进行对应定义的计算,dim为在第几维进行相应的运算。e.g.scatter_sum即进行sum运算,scatter_mean即进行mean运算。
torch_scatter.scatter详解_第1张图片
e.g.x=scatter_mean(data.x, data.batch, dim=0)
我们给定一个二维张量x[952,21]为src,一维张量batch[952]为index,scatter_mean则将batch相同元素对应的x元素在0维上进行mean计算。具体过程如下:
torch_scatter.scatter详解_第2张图片

你可能感兴趣的:(pytorch,python)