官方解释: torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
这里可以查看官方解释torch.gather()。
个人理解:
torch.gather()的作用?torch.gather()可以做什么?
从某个tensor(input)按照某一纬中选取一部分数据。
怎么选?
需要提供一个index数组,并且指定一个维度i。
注意:这个index数组除了第i维之外,其他维度的大小要和input保持一致。 数据类型为LongTensor。
结果是什么样的?
out的shape与index的shape一样。
具体例子:
问题描述:
现有两个数据a和b,大小均为[3,4,2]要求获取a第一维最大的坐标对应的b的值。
# 获取a和b
a = torch.softmax(torch.randn(3, 4, 2), dim=1)
b = torch.softmax(torch.randn(3, 4, 2), dim=1)
# 按照a的第一维获取最大值坐标
max_idx = torch.argmax(a, dim=1,keepdim=True)
print(a)
"""
tensor([[[0.2516, 0.1149],
[0.4224, 0.3796],
[0.1473, 0.3753],
[0.1787, 0.1302]],
[[0.1428, 0.6113],
[0.3139, 0.1165],
[0.0442, 0.1050],
[0.4992, 0.1672]],
[[0.4751, 0.2854],
[0.0598, 0.4888],
[0.1270, 0.1774],
[0.3381, 0.0484]]])
"""
print(max_idx)
"""
tensor([[[1, 1]],
[[3, 0]],
[[0, 1]]])
"""
print(b)
"""
tensor([[[0.2788, 0.1393],
[0.0464, 0.0557],
[0.4017, 0.0415],
[0.2730, 0.7635]],
[[0.3165, 0.1508],
[0.0246, 0.0250],
[0.4308, 0.0922],
[0.2282, 0.7319]],
[[0.1720, 0.1322],
[0.0400, 0.5672],
[0.0937, 0.1537],
[0.6943, 0.1468]]])
"""
out = torch.gather(b, dim=1, index=max_idx)
print(out)
"""tensor([[[0.0464, 0.0557]],
[[0.2282, 0.1508]],
[[0.1720, 0.5672]]])
"""
具体实现:
1. 确认index的shape和index数组中每个元素的索引坐标:
tensor([[[1, 1]],
[[3, 0]],
[[0, 1]]])
shape: [3,1,2] 这里有三个[],所以是3维,但是第二个[]是1维的。
六个元素依次是: 1,1,3,0,0,1
他们对应的index数组的索引依次是:
1:[0,0,0]
1:[0,0,1]
3:[1,0,0]
0:[1,0,1]
0:[2,0,0]
1:[2,0,1]
2. 将指定维度的索引值更换为相应的index数组的值。
例如:1:[0,0,0],本例中dim=1,也就是将第二个0换位对应的值(1),即[0,0,0]→[0,1,0]。其他的依次类推即可得到如下结果。
1:[0,0,0] → [0,1,0]
1:[0,0,1] → [0,1,1]
3:[1,0,0] → [1,3,0]
0:[1,0,1] → [1,0,1]
0:[2,0,0] → [2,0,0]
1:[2,0,1] → [2,1,1]
3. 根据2.中的索引数组获取对应的元素。
结合上面的例子再次进行理解:
两个数据a和b,大小均为[3,4,2]。要求获取a的第一维值最大的坐标对应的b的值。
首先a的第一维值最大的坐标的shape应该是[3,4,2]→[3,1,2],第一维的四组数只留下最大的那一组的坐标,其中[3,1,2]的值就是对应的坐标。
那么要想获取对应的b的值,只需要将第一维的数值更改为相应的值即可。
官方解释: Tensor.scatter_(dim, index, src, reduce=None) → Tensor
这里可以查看官方解释tensor.scatter_()
直观理解
gather 有“采集”的意思,通常是指对自己的东西进行采集。所以对自己需要一个新的index。即右边需要一个新的index。
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
scatter 有“播种”的意思,通常是指将一个东西播种到另外一个地方,另外的地方需要一个新的index。即左边需要新的index。
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
举例
a = torch.softmax(torch.randn(3, 4, 2), dim=1)
b = torch.softmax(torch.randn(3, 4, 2), dim=1)
max_idx = torch.argmax(a, dim=1,keepdim=True)
print(max_idx)
"""
tensor([[[1, 2]],
[[0, 2]],
[[0, 0]]])
"""
print(a)
"""
tensor([[[0.0557, 0.1105],
[0.7900, 0.0657],
[0.0282, 0.6790],
[0.1260, 0.1449]],
[[0.5063, 0.1406],
[0.4020, 0.0998],
[0.0374, 0.4958],
[0.0543, 0.2638]],
[[0.3352, 0.5519],
[0.2143, 0.1511],
[0.2010, 0.1961],
[0.2495, 0.1009]]])
"""
print(b)
"""
tensor([[[0.3729, 0.0256],
[0.2451, 0.1061],
[0.1753, 0.0432],
[0.2067, 0.8251]],
[[0.0787, 0.2011],
[0.0785, 0.3608],
[0.0498, 0.3005],
[0.7930, 0.1377]],
[[0.0140, 0.4980],
[0.0107, 0.1453],
[0.7798, 0.2431],
[0.1955, 0.1136]]])
"""
# out = torch.gather(b, dim=1, index=max_idx)
src = torch.zeros_like(a)
out = src.scatter_ (dim=1, index=max_idx, src=b)
print(out)
"""
tensor([[[0.0000, 0.0000],
[0.3729, 0.0000],
[0.0000, 0.0256],
[0.0000, 0.0000]],
[[0.0787, 0.0000],
[0.0000, 0.0000],
[0.0000, 0.2011],
[0.0000, 0.0000]],
[[0.0140, 0.4980],
[0.0000, 0.0000],
[0.0000, 0.0000],
[0.0000, 0.0000]]])
"""
具体实现:
1. 确认index的shape和index数组中每个元素的索引坐标:
tensor([[[1, 2]],
[[0, 2]],
[[0, 0]]])
shape: [3,1,2] 这里有三个[],所以是3维,但是第二个[]是1维的。
六个元素依次是: 1,2,0,2,0,0
他们对应的index数组的索引依次是:
1:[0,0,0]
2:[0,0,1]
0:[1,0,0]
2:[1,0,1]
0:[2,0,0]
0:[2,0,1]
2. 将指定维度的索引值更换为相应的index数组的值。
例如:1:[0,0,0],本例中dim=1,也就是将第二个0换位对应的值(1),即[0,0,0]→[0,1,0]。其他的依次类推即可得到如下结果。
1:[0,0,0] → [0,1,0]
2:[0,0,1] → [0,2,1]
0:[1,0,0] → [1,0,0]
2:[1,0,1] → [1,2,1]
0:[2,0,0] → [2,0,0]
0:[2,0,1] → [2,0,1]
3. 根据2.中的索引位置填入src中相应的数值。