在pytorch中,scatter是一个非常实用的映射函数,其将一个源张量(src)中的值按照指定的轴方向(dim)和对应的位置关系(index)逐个填充到目标张量(target)中,其函数写法为:
target.scatter(dim, index, src)
其中各变量及参数的说明如下:
target
:即目标张量,将在该张量上进行映射src
:即源张量,将把该张量上的元素逐个映射到目标张量上dim
:指定轴方向,定义了填充方式。对于二维张量,dim=0
表示逐列进行行填充,而dim=1
表示逐行进行列填充index
: 按照轴方向,在target
张量中需要填充的位置dim 0:
把a按顺序(一行一行遍历)给b的索引(index)赋值,index是行编号
实际就是把a的行按照新的顺序赋值给b,顺序就是index行编号。
列子1:
import torch
a = (torch.arange(10) + 1).reshape(5, 2).float()
print(a)
print("-------------------------------------")
b = torch.zeros(5, 3)
b_ = b.scatter(dim=0, index=torch.LongTensor([[4, 2], [3, 0], [2, 0], [1, 0], [0, 0]]), src=a)
print(b_)
结果:
tensor([[ 1., 2.],
[ 3., 4.],
[ 5., 6.],
[ 7., 8.],
[ 9., 10.]])
-------------------------------------
tensor([[ 9., 10., 0.],
[ 7., 0., 0.],
[ 5., 2., 0.],
[ 3., 0., 0.],
[ 1., 0., 0.]])
例子2:
import torch
a = (torch.arange(10)+1).reshape(2,5).float()
print(a)
print("-------------------------------------")
b = torch.zeros(3, 5)
b_= b.scatter(dim=0, index=torch.LongTensor([[0, 2]]),src=a)
# b 0行 a 1列 1行
# b 2行 a 2列 1行
print(b_)
print("-------------------------------------")
b_= b.scatter(dim=0, index=torch.LongTensor([[0, 2, 1, 1, 2], [2, 0, 2, 1, 0]]),src=a)
# 7是因为两个 第2行 第4列 值发生覆盖了。
# 第1行 第1列
# 第3行 第2列
# 第2行 第3列
# 第2行 第4列
# 第3行 第5列
# 第3行 第1列
# 第1行 第2列
# 第3行 第3列
# 第2行 第4列
# 第1行 第5列
print(b_)
结果:
tensor([[ 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10.]])
-------------------------------------
tensor([[1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 2., 0., 0., 0.]])
-------------------------------------
tensor([[ 1., 7., 0., 0., 10.],
[ 0., 0., 3., 9., 0.],
[ 6., 2., 8., 0., 5.]])
dim1:
把a按顺序(一行一行遍历)给b的索引(index)赋值,index是列编号
实际就是把a的行重新设置,赋值到b的行上,新的位置,就是index索引(列编号位置)。
import torch
a = (torch.arange(10)+1).reshape(2,5).float()
print(a)
print("-------------------------------------")
b = torch.zeros(3, 5)
b_= b.scatter(dim=1, index=torch.LongTensor([[0, 2]]),src=a)
#b 0列 第1行, a 0列 第1行
#b 2列 第1行 ,a 1列 第1行
print(b_)
print("-------------------------------------")
b_= b.scatter(dim=1, index=torch.LongTensor([[0, 2, 1, 1, 2], [2, 0, 2, 1, 0]]),src=a)
#把a的第1行按顺序 放在b的第1行上,顺序是index
#4的来源:
0, 2, 1, [1], 2
#把a的第2行按顺序 放在b的第2行上,顺序是index
print(b_)
结果:
tensor([[ 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10.]])
-------------------------------------
tensor([[1., 0., 2., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
-------------------------------------
tensor([[ 1., 5., 2., 0., 0.],
[10., 9., 8., 0., 0.],
[ 0., 0., 0., 0., 0.]])