在torch._C._TensorBase.py
中,定义了scatter_(self, dim, index, src, reduce=None) -> Tensor
方法,作用是将src
的值写入index
指定的self
相关位置中。用一个三维张量举例如下,将src
在坐标(i,j,k)
下的所有值,写入self
的相应位置,而self
的位置坐标除了dim
维度用index[i,j,k]
代替以外,都不变:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0,用index[i][j][k]替换i坐标
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1,用index[i][j][k]替换j坐标
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2,用index[i][j][k]替换k坐标
要求:
self
,index
,src
必须有相同的维数;index
在任意维度的size
必须小于等于self
和src
对应维度的size
self
和index
中元素的类型必须一致,dtype
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
[ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000],
[ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])
"""
理解一下:
self是一个shape为(3,5)的全零tensor;
index是一个shape为(2,5)的tensor;
x同index的shape相同,不相同也可。
dim=0,意味着index需要修改第0维坐标;
原始坐标为:00,01,02,03,04;10,11,12,13,14
更新的横坐标依次为:01200;20012
更新的纵坐标依次为:01234;01234
对应组合,更新坐标为:00,11,22,03,04;20,01,02,13,24
然后用x在原始坐标下的值填写到self更新后的坐标位置,将原始坐标和更新坐标对应来看。
具体来看:
x new_self
00 00
01 11
02 22
03 03
04 04
10 20
11 01
12 02
13 10
14 24
"""
图示上述例子:
>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z
tensor([[ 0.0000, 0.0000, 1.2300, 0.0000],
[ 0.0000, 0.0000, 0.0000, 1.2300]])
"""
理解一下:一个2*1的index_tensor(一个2维张量,两个维度的size分别是2和1,对应两个值为2和3),dim=1,需要修改的就是1维。
原来的坐标是00,10;修改后的坐标是02,13。
然后用目标值1.23去替换self中坐标02,13的值,得到上述结果。
"""
>>> z = torch.ones(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23, reduce='multiply')
>>> z
tensor([[1.0000, 1.0000, 1.2300, 1.0000],
[1.0000, 1.0000, 1.0000, 1.2300]])
"""
同上:用目标值找到self在更新坐标位置的值,乘以目标值1.23得到更新后的矩阵。
"""
类似于上述方法,在python中还包括scatter_add(dim, index, src) -> Tensor
用于实现将src
按照index
位置累加到self
上。
分为以下几个步骤:
raw_index
中;dim
和index
修改原始坐标,得到新的坐标index_pos
;self_tensor
在index_pos
位置的值要累加上other_tensor
在raw_index
位置的值import torch
import numpy as np
from torch import Tensor
"""
@overload
def scatter_add(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...
@overload
def scatter_add(self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: ...
def scatter_add_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...
对pytorch中的scatter_add函数的理解和简单测试:
# 参数:tensor,dim,index,tensor
# 返回:tensor
# 功能:将other_tensor的值累加到self_tensor的相应位置,用index_tensor对应位置的值替换掉self_tensor下标的dim维
# 举例:
self_tensor = [[1, 2], [3, 4]] shape=(2,2)
other_tensor = [[5, 6], [7, 8]] shape=(2,2)
index_tensor = [[0, 0], [1, 1]] shape=(2,2)
dim = 1
以上三个tensor的shape必须一致,下标为:[0,0] [0,1] [1,0] [1,1]
dim=1,那么,self_tensor的第1维下标由index_tensor表示,[0,0] [0,0] [1,1] [1,1]
则:
self_tensor[0,0] = 1 + 5 + 6 = 12
self_tensor[0,1] = 2
self_tensor[1,0] = 3
self_tensor[1,1] = 4 + 7 + 8 = 19
"""
def scatter_add(input_tensor: torch.Tensor, dim: int, index: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
# tensor的维数是不确定的,因此无法用for循环的方式
# 如果tensor是2维,那么dim=0或1,两层for循环,用other对self进行填充
# 如果tensor是3维,那么dim=0、1、2,需要三层for循环来遍历other
if input_tensor.dim() == 2:
for i in range(index_tensor.size()[0]):
for j in range(index_tensor.size()[1]):
if dim == 0: # self矩阵的第0维索引
self_tensor[index_tensor[i][j]][j] += other_tensor[i][j]
elif dim == 1: # self矩阵的第1维索引
self_tensor[i][index_tensor[i][j]] += other_tensor[i][j]
elif input_tensor.dim() == 3:
pass
return self_tensor
if __name__ == '__main__':
index_tensor = torch.tensor([[0, 0], [1, 1]])
print('index_tensor: \n', index_tensor.dim())
self_tensor = torch.arange(1, 5).view(2, 2)
print('self_tensor: \n', self_tensor)
other_tensor = torch.arange(5, 9).view(2, 2)
print('other_tensor: \n', other_tensor)
dim = 1
for i in range(index_tensor.size()[0]):
for j in range(index_tensor.size()[1]):
replace_index = index_tensor[i][j]
print(i, j, replace_index)
if dim == 0:
# self矩阵的第0维索引
self_tensor[replace_index][j] += other_tensor[i][j]
elif dim == 1:
# self矩阵的第1维索引
self_tensor[i][replace_index] += other_tensor[i][j]
print(self_tensor)
index_tensor = torch.tensor([[0, 1], [1, 1]])
print('index_tensor: \n', index_tensor)
self_tensor = torch.arange(0, 4).view(2, 2)
print('self_tensor: \n', self_tensor)
other_tensor = torch.arange(5, 9).view(2, 2)
print('other_tensor: \n', other_tensor)
self_tensor.scatter_add_(dim=0, index=index_tensor, src=other_tensor)
print(self_tensor)
import numpy as np
a = np.arange(3 * 4 * 5).reshape((3, 4, 5))
print(a)
"""
[[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]]
[[20 21 22 23 24]
[25 26 27 28 29]
[30 31 32 33 34]
[35 36 37 38 39]]
[[40 41 42 43 44]
[45 46 47 48 49]
[50 51 52 53 54]
[55 56 57 58 59]]]
"""
# 三维数组,下标,举例(1,2,0)
# 第一种方式,所有语言共同的读取方式,一般通过多层循环嵌套生成不同维度下标来读取
print(a[1][2][0]) # 30
# 第二种方式,python独有,将需要的三个维度下标位置直接放入中括号中,就可以读取;
# 适合于不同维度的数组,通过已知的下标位置读取值
print(a[1, 2, 0]) # 30
print(a[(1, 2, 0)]) # 30
# 第三种方式,一般将下标位置提前用list存储,只能得到多个list组合的数组;要想达到上述要求,可以将list转为tuple
pos = [1, 2, 0]
print(a[pos]) # 正解:a[tuple(pos)]
"""
[[[20 21 22 23 24]
[25 26 27 28 29]
[30 31 32 33 34]
[35 36 37 38 39]]
[[40 41 42 43 44]
[45 46 47 48 49]
[50 51 52 53 54]
[55 56 57 58 59]]
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]]]
"""
按照dim修改坐标位置时,需要用到深拷贝,可以参考这篇博文。 Java基础-Cloneable接口,深浅拷贝【附python,C++深拷贝、浅拷贝】