首先看一下这个函数的接口,需要三个输入:1)维度dim 2)索引数组index 3)原数组src,为了方便理解,我们后文把src换成input表示。最终的输出是新的output数组。
下面依次介绍:
1)维度dim:整数,可以是0,1,2,3…
2)索引数组index:索引数组是一个tensor,其中的数据类型是整数,表示位置
3)原数组input:也是一个tensor,其中的数据类型任意
先说一下这个函数是干嘛的,在我看来,这个scatter函数就是把input数组中的数据进行重新分配。index中表示了要把原数组中的数据分配到output数组中的位置,如果未指定,则填充0。
比如说下面这段代码:
import torch
input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)
运行结果如下:
tensor([[-0.2558, -1.8930, -0.7831, 0.6100],
[ 0.3246, 2.1289, 0.5887, 1.5588]])
tensor([[ 0.6100, -1.8930, -0.7831, -0.2558, 0.0000],
[ 0.5887, 0.3246, 2.1289, 1.5588, 0.0000]])
下面,我详细说一下为什么会是这样的结果。
前面说了,scatter是input数组,根据index数组,对input数组中的数据进行重新分配,我们看一下分配过程是怎样的。
input:
tensor([[-0.2558, -1.8930, -0.7831, 0.6100],
[ 0.3246, 2.1289, 0.5887, 1.5588]])
index:
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output:
tensor([[ 0.6100, -1.8930, -0.7831, -0.2558, 0.0000],
[ 0.5887, 0.3246, 2.1289, 1.5588, 0.0000]])
首先,对input[0][0]进行重分配。符号 -> 代表赋值。由于scatter方法的第一维dim=1,所以input数组中的数据只是在第1维上进行重新分配,第0维不变。以二维数组举例,第一行的数据重新分配后一定在还是第一行,不能跑到第二行。
input[0][0] -> output[0][index[0][0]] = output[0][3]
数据位置发生的变化都是在第1维上,第0维不变。
input[0][1] -> output[0][index[0][1]] = output[0][1]
input[0][2] -> output[0][index[0][2]] = output[0][2]
input[0][3] -> output[0][index[0][3]] = output[0][0]
为了方便理解,我们是按照input中数据的顺序索引的,但是在pytorch中,是根据从index[0][0]到index[0][3]这样的顺序去索引的,索引的input位置和output的位置必须要存在,否则会提示错误。但是,不一定所有的input数据都会分到output中,output也不是所有位置都有对应的input,当output中没有对应的input时,自动填充0。
一般scatter用于生成onehot向量,如下所示:
index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)
输出结果是:
tensor([[0., 1., 0., 0.],
[0., 0., 1., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 1.]])
如果input是一个数字的话,代表这用于分配到output的数字是多少。
import torch
tensorB = torch.tensor([[2.5880, 2.1556, -31.0650, -13.5238, 11.0284],
[-0.2982, 10.8633, -22.4874, -9.2778, -1.1321]])
tensorA = torch.tensor([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]])
index = torch.tensor([
[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]
])
tensorC = tensorA.scatter_(0, index, tensorB) # dim=0: 按列填充
print('tensorC = ', tensorC)
# tensorC = tensor([[ 2.5880, 10.8633, -22.4874, -13.5238, 11.0284, 0.0000],
# [ 0.0000, 2.1556, 0.0000, -9.2778, 0.0000, 0.0000],
# [ -0.2982, 0.0000, -31.0650, 0.0000, -1.1321, 0.0000]])
tensorD = tensorA.scatter_(1, index, tensorB) # dim=1: 按行填充
print('tensorD = ', tensorD)
# tensorD = tensor([[ 11.0284, 2.1556, -31.0650, -13.5238, 11.0284, 0.0000],
# [-22.4874, -9.2778, -1.1321, -9.2778, 0.0000, 0.0000],
# [ -0.2982, 0.0000, -31.0650, 0.0000, -1.1321, 0.0000]])
参考资料:
官方TORCH.TENSOR.SCATTER_
pytorch中torch.Tensor.scatter用法
one hot编码:torch.Tensor.scatter_()
函数用法详解