【Pytorch学习笔记】torch.gather()与tensor.scatter_()

torch.gather()

官方解释: torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
这里可以查看官方解释torch.gather()。
【Pytorch学习笔记】torch.gather()与tensor.scatter_()_第1张图片
个人理解:
torch.gather()的作用?torch.gather()可以做什么?
从某个tensor(input)按照某一纬中选取一部分数据
怎么选?
需要提供一个index数组,并且指定一个维度i。
注意:这个index数组除了第i维之外,其他维度的大小要和input保持一致。 数据类型为LongTensor。
结果是什么样的?
out的shape与index的shape一样。

具体例子:
问题描述:
现有两个数据a和b,大小均为[3,4,2]要求获取a第一维最大的坐标对应的b的值。

# 获取a和b
a = torch.softmax(torch.randn(3, 4, 2), dim=1)
b = torch.softmax(torch.randn(3, 4, 2), dim=1)
# 按照a的第一维获取最大值坐标
max_idx = torch.argmax(a, dim=1,keepdim=True)
print(a)
"""
tensor([[[0.2516, 0.1149],
         [0.4224, 0.3796],
         [0.1473, 0.3753],
         [0.1787, 0.1302]],

        [[0.1428, 0.6113],
         [0.3139, 0.1165],
         [0.0442, 0.1050],
         [0.4992, 0.1672]],

        [[0.4751, 0.2854],
         [0.0598, 0.4888],
         [0.1270, 0.1774],
         [0.3381, 0.0484]]])
"""
print(max_idx)
"""
tensor([[[1, 1]],

        [[3, 0]],

        [[0, 1]]])
"""
print(b)
"""
tensor([[[0.2788, 0.1393],
         [0.0464, 0.0557],
         [0.4017, 0.0415],
         [0.2730, 0.7635]],

        [[0.3165, 0.1508],
         [0.0246, 0.0250],
         [0.4308, 0.0922],
         [0.2282, 0.7319]],

        [[0.1720, 0.1322],
         [0.0400, 0.5672],
         [0.0937, 0.1537],
         [0.6943, 0.1468]]])

"""
out = torch.gather(b, dim=1, index=max_idx)
print(out)
"""tensor([[[0.0464, 0.0557]],

        [[0.2282, 0.1508]],

        [[0.1720, 0.5672]]])
"""

具体实现:

1. 确认index的shape和index数组中每个元素的索引坐标:
tensor([[[1, 1]],
        [[3, 0]],
        [[0, 1]]])
shape: [3,1,2] 这里有三个[],所以是3维,但是第二个[]1维的。
六个元素依次是: 1,1,3,0,0,1
他们对应的index数组的索引依次是:
1:[0,0,0]
1:[0,0,1]
3:[1,0,0]
0:[1,0,1]
0:[2,0,0]
1:[2,0,1]

2. 将指定维度的索引值更换为相应的index数组的值。
例如:1:[0,0,0],本例中dim=1,也就是将第二个0换位对应的值(1),[0,0,0][0,1,0]。其他的依次类推即可得到如下结果。
1:[0,0,0][0,1,0]
1:[0,0,1][0,1,1]
3:[1,0,0][1,3,0]
0:[1,0,1][1,0,1]
0:[2,0,0][2,0,0]
1:[2,0,1][2,1,1]

3. 根据2.中的索引数组获取对应的元素。

结合上面的例子再次进行理解:
两个数据a和b,大小均为[3,4,2]。要求获取a的第一维值最大的坐标对应的b的值。
首先a的第一维值最大的坐标的shape应该是[3,4,2][3,1,2],第一维的四组数只留下最大的那一组的坐标,其中[3,1,2]的值就是对应的坐标。
那么要想获取对应的b的值,只需要将第一维的数值更改为相应的值即可。

tensor.scatter_()

官方解释: Tensor.scatter_(dim, index, src, reduce=None) → Tensor
这里可以查看官方解释tensor.scatter_()
【Pytorch学习笔记】torch.gather()与tensor.scatter_()_第2张图片

  • tensor.scatter_()是将从a中筛选的数据填充到b中。改变b的索引,不改变a的索引。
  • torch.gather()是按照某种规则对b的数据进行筛选。不改变b的索引。
  • tensor.scatter_()是左边改变index,torch.gather()是右边改变index。 index数组除了第i维之外,其他维度的大小要和input保持一致。 数据类型为LongTensor。
    注意:torch.gather()的结果是与index数值shape一样,但是tensor.scatter_()可以与a或者b一样。

直观理解
gather 有“采集”的意思,通常是指对自己的东西进行采集。所以对自己需要一个新的index。即右边需要一个新的index。

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

scatter 有“播种”的意思,通常是指将一个东西播种到另外一个地方,另外的地方需要一个新的index。即左边需要新的index。

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

举例

a = torch.softmax(torch.randn(3, 4, 2), dim=1)
b = torch.softmax(torch.randn(3, 4, 2), dim=1)
max_idx = torch.argmax(a, dim=1,keepdim=True)
print(max_idx)
"""
tensor([[[1, 2]],
        [[0, 2]],
        [[0, 0]]])
"""
print(a)
"""
tensor([[[0.0557, 0.1105],
         [0.7900, 0.0657],
         [0.0282, 0.6790],
         [0.1260, 0.1449]],

        [[0.5063, 0.1406],
         [0.4020, 0.0998],
         [0.0374, 0.4958],
         [0.0543, 0.2638]],
         
        [[0.3352, 0.5519],
         [0.2143, 0.1511],
         [0.2010, 0.1961],
         [0.2495, 0.1009]]])
"""
print(b)
"""
tensor([[[0.3729, 0.0256],
         [0.2451, 0.1061],
         [0.1753, 0.0432],
         [0.2067, 0.8251]],

        [[0.0787, 0.2011],
         [0.0785, 0.3608],
         [0.0498, 0.3005],
         [0.7930, 0.1377]],

        [[0.0140, 0.4980],
         [0.0107, 0.1453],
         [0.7798, 0.2431],
         [0.1955, 0.1136]]])
"""
# out = torch.gather(b, dim=1, index=max_idx)
src = torch.zeros_like(a)
out = src.scatter_ (dim=1, index=max_idx, src=b)
print(out)
"""
tensor([[[0.0000, 0.0000],
         [0.3729, 0.0000],
         [0.0000, 0.0256],
         [0.0000, 0.0000]],

        [[0.0787, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.2011],
         [0.0000, 0.0000]],

        [[0.0140, 0.4980],
         [0.0000, 0.0000],
         [0.0000, 0.0000],
         [0.0000, 0.0000]]])
"""

具体实现:

1. 确认index的shape和index数组中每个元素的索引坐标:
tensor([[[1, 2]],
        [[0, 2]],
        [[0, 0]]])
shape: [3,1,2] 这里有三个[],所以是3维,但是第二个[]1维的。
六个元素依次是: 1,2,0,2,0,0
他们对应的index数组的索引依次是:
1:[0,0,0]
2:[0,0,1]
0:[1,0,0]
2:[1,0,1]
0:[2,0,0]
0:[2,0,1]

2. 将指定维度的索引值更换为相应的index数组的值。
例如:1:[0,0,0],本例中dim=1,也就是将第二个0换位对应的值(1),[0,0,0][0,1,0]。其他的依次类推即可得到如下结果。
1:[0,0,0][0,1,0]
2:[0,0,1][0,2,1]
0:[1,0,0][1,0,0]
2:[1,0,1][1,2,1]
0:[2,0,0][2,0,0]
0:[2,0,1][2,0,1]

3. 根据2.中的索引位置填入src中相应的数值。

你可能感兴趣的:(深度学习,pytorch,学习,python)