scatter_
(dim, index, src) → Tensor
首先看一下这个函数的接口,需要三个输入:1)维度dim; 2)索引数组index; 3)原数组src,为了方便理解,我们后面把src换成input表示。最终的输出是新的output数组。即 scatter_(dim, index, input) → Tensor
下面依次介绍:
1)维度dim:整数,可以是0,1,2,3...
2)索引数组index:索引数组是一个tensor,其中的数据类型是整数,表示位置
3)原数组input:也是一个tensor,其中的数据类型任意
先说一下这个函数是干sa的,在我看来,这个scatter_函数就是把input数组中的数据进行重新分配。index中的数表示了要把原数组中的数据分配到output数组中的位置,如果未指定,则填充为0。
注意scatter_
函数是inplace操作。
比如说下面这段代码:
import torch input = torch.randn(2, 4) print(input) output = torch.zeros(2, 5) index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]]) output = output.scatter(1, index, input) print(output)
运行结果如下:
下面,详细说一下为什么会是这样的结果。
前面说了,scatter是input数组根据index数组对input数组中的数据进行重新分配,我们看一下分配过程是怎样的。
input:
tensor([[-0.9817, -2.3192, 0.1756, 1.2975],
[ 0.8049, 0.4067, -0.0477, -0.3837]])
index:
torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output:
tensor([[ 1.2975, -2.3192, 0.1756, -0.9817, 0.0000],
[-0.0477, 0.8049, 0.4067, -0.3837, 0.0000]])
首先,对input[0][0]进行重分配。
NOTE: 符号 -> 代表赋值。由于scatter方法的第一维dim=1,所以input数组中的数据只是在第1维上进行重新分配,第0维不变。以二维数组举例,第一行的数据重新分配后一定在还是第一行,不能跑到第二行。
input[0][0] -> output[0][index[0][0]] = output[0][3]
数据位置发生的变化都是在第1维上,第0维不变。
input[0][1] -> output[0][index[0][1]] = output[0][1]
input[0][2] -> output[0][index[0][2]] = output[0][2]
input[0][3] -> output[0][index[0][3]] = output[0][0]
为了方便理解,我是按照input中数据的顺序索引的,但是在pytorch中,是根据从index[0][0]到index[0][3]这样的顺序去索引的,索引的input位置和output的位置必须要存在,否则会提示错误。但是,不一定所有的input数据都会分到output中,output也不是所有位置都有对应的input,当output中没有对应的input时,自动填充0。