函数 tensor.scatter_(dim, index, src)
返回值:返回一个根据index映射关系映射后的新的tensor
参数解释:dim 变化的维度
index 映射关系
src 输入的tensor
代码示例:
import torch
x = torch.FloatTensor([[ 1, 2, 3, 4,5],
[6, 7,8, 9,10]])
result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]])
result.scatter_(dim = 0, index = indices, src = x)
输出结果:
tensor([[ 1., 7., 8., 4., 5.],
[ 0., 2., 0., 9., 0.],
[ 6., 0., 3., 0., 10.]])
代码解释:
result.scatter_(dim = 0, index = indices, src = x)
dim=0:
那么转换过程中,只改src中各个元素的行,不改变列。从最后的结果可以看到:
x中第一列 1和6,转换后输出结果中 1和6依然在第一列,只是行发生了变化。换言之,行根据index参数的映射关系进行了映射。那么怎么理解index的映射关系呢?
indices = torch.tensor([[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]])
既然列不变,不妨我们纵向来看:
index参数中的 第0行第0列是0,那么就把src中的[0][0] =1,映射到 [0][0] = 1
index参数中的 第1行第0列是2,那么就把src中的[1][0] =6,映射到 [2][0] = 6
index第0列已经全部映射完毕,但是第0列还有一个[1][0]是空的,那就自动赋值为0
于是,结果中的 第一列 分别为 1 0 6
那如果dim = 1 会发生什么呢?先上代码
代码示例:
import torch
x = torch.FloatTensor([[ 1, 2, 3, 4,5],
[6, 7,8, 9,10]])
result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]])
result.scatter_(dim = 1, index = indices, src = x)
输出结果:
tensor([[ 5., 2., 3., 0., 0.],
[ 8., 9., 10., 0., 0.],
[ 0., 0., 0., 0., 0.]])
再分析一次:
dim = 1,即映射关系是列映射关系。那么行不变。我们横向分析,以第0行为例:
index参数中的 第0行第0列是0,那么就把src中的[0][0] =1,映射到 [0][0] = 1
index参数中的 第0行第1列是1,那么就把src中的[0][1] =2,映射到 [0][1] =2
index参数中的 第0行第2列是2,那么就把src中的[0][2] =3,映射到 [0][2] = 3
index参数中的 第0行第3列是0,那么就把src中的[0][3] =4,映射到 [0][0] =4
index参数中的 第0行第4列是0,那么就把src中的[0][4] =5,映射到 [0][0] =5
index第0列已经全部映射完毕,未赋值的,自动赋0
于是,结果中的 第1行结果为 5 2 3 0 0
那最后一行为什么都是0呢?
因为dim = 1,要求只改变第二维度,行维度不变,index中没有第2行的映射关系。也就是说,再映射过程中,源数据src的元素在哪一行,映射结果中就在哪一行。源数据中第二行没有元素,那映射结果中的第二行只能赋0。
注意:本文中下标都是0开始。
作者:LambAI https://www.bilibili.com/read/cv18787354?spm_id_from=333.999.list.card_article.click 出处:bilibili