torch.scatter算子详解

0 scatter理解

关于该算子, torch 官方的文档https://pytorch-cn.readthedocs.io/zh/latest/package_references/Tensor/#scatter_input-dim-index-src-tensor 是这么解释的:
torch.scatter算子详解_第1张图片
刚开始看了几次, 挺费解的。仔细理解之后, 发现其实用法也挺简单的。 这个操作的作用就是把src这个Tensor的值给更新到input这个Tensor中。 那更新到哪些位置呢, 就是由index 和dim去确定了。

拿上面的例子来说, 输入input是一个维度为[3,5]的Tensor, src是一个维度为[2,5]的Tensor。 那么很自然的, 要把src中这10个值更新到input中, 显然需要10个位置索引去决定更新到哪些位置。 那么如何去表示这10个索引呢, 这里用index 和dim两个值共同确定。 首先dim=0, 表示index是按第0维度(也就是行)去索引input的。

index的第一行是[0,1,2,0,0] 就表示把src中的第一行的5个值分别更新到input的第0行, 第1行, 第2行, 第0行和第0行。 列是相同的, 也就是src中第1列对应的也是input中的第一列。 其实把index写完整就更好理解了,完整的index应该是[[0,0],[1,1],[2,2],[0,3],[0,4]]. 因为列是相同的, 所以省去了列值, 这导致有一些不好理解。 其实如果写出完整的index也就不需要dim这个参数了。 之所以不用完整的索引值, 而是用不完全的index和dim共同确定最终的index, 应该是为了简化index的写法。

index的第一行是[2,0,0,1,2] 就表示把src中的第一行的5个值分别更新到input的第2行, 第0行, 第0行, 第1行和第2行。完整的index应该是[[2,0],[0,1],[0,2],[1,3],[2,4]].

下面的图非常直观的表示了这一过程:
torch.scatter算子详解_第2张图片

1 作用

上面花了较大的篇幅介绍了scatter的具体作用,看着还挺复杂的。 那么这个操作到底有什么用呢? 实际上, 这个操作基本上都用在one-hot的操作中。 one-hot的操作中, 就需要用到这个操作, 把索引指向位置的值更新为1.

你可能感兴趣的:(基础知识,pytorch,深度学习,python)