关于torch.scatter函数

A.scatter_(dim, index, B) # 基本用法, tensor A 被就地scatter到 tensor B

看完一头懵

然后知乎上找了个图torch.scatter_直观理解官网示例 - 知乎

关于torch.scatter函数_第1张图片

刚开始还是一头懵,后来发现是这个样子的。dim=0表示按行放置,源tensor的第0行第0列元素(也就是0.3992)放在新tensor的第0行(因为index是0), 源tensor的第0行第1列元素(0.2908)放在新tensor的第1行(因为index是1),源tensor的第0行第3列元素(0.9044)放在新tensor的第2行(因为index是2)...以此类推。

dim 和 index

这两个参数是配套的。index和源tensor维度一致(也可以为空,就不改变目标tensor),对于n-D tensor,dim可以为0~N-1。

你可能感兴趣的:(pytorch,pytorch)