本文主要参考 https://zhuanlan.zhihu.com/p/352877584 ,在对 torch.gather() 理解之后,总结了一套比较好用的计算方法,下面直接来看例子。
>>> import torch
>>> tensor_0 = torch.arange(3, 12).view(3, 3)
Test 1
>>> index = torch.tensor([[2, 1, 0]])
>>> tensor_1 = tensor_0.gather(0, index)
# tensor_0
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
# tensor_1
tensor([[9, 7, 5]])
首先根据 index 的 shape = (1, 3),我们将这个水平长条对应到 tensor_0 的 [[3, 4, 5]] 上(尽量往左上角放)。gather() 函数中第一个维度参数为 0,表明沿着竖直方向根据 index 替换数值,例如 [[3, 4, 5]] 中的 3,由于其对应位置 index 为 [[2, 1, 0]] 中的 2,因此将 3 替换成该列的 9。以此类推,4 换成 7,5 还是 5。
Test 2
>>> index = torch.tensor([[2, 1, 0]])
>>> tensor_1 = tensor_0.gather(1, index)
# tensor_0
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
# tensor_1
tensor([[5, 4, 3]])
首先根据 index 的 shape = (1, 3),我们将这个水平长条对应到 tensor_0 的 [[3, 4, 5]] 上。gather() 函数中第一个维度参数为 1,表明沿着水平方向根据 index 替换数值,例如 [[3, 4, 5]] 中的 3,由于其对应位置 index 为 [[2, 1, 0]] 中的 2,因此将 3 替换成该行的 5。以此类推,4 还是 4,5 换成 3。
Test 3
>>> index = torch.tensor([[2, 1, 0]]).t() # 表示一列 [[2], [1], [0]]
>>> tensor_1 = tensor_0.gather(1, index)
# tensor_0
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
# tensor_1
tensor([[5],
[7],
[9]])
首先根据 index 的 shape = (3, 1),我们将这个竖直长条对应到 tensor_0 的 [[3], [6], [9]] 上。gather() 函数中第一个维度参数为 1,表明沿着水平方向根据 index 替换数值,例如 [[3], [6], [9]] 中的 3,由于其对应位置 index 为 [[2], [1], [0]] 中的 2,因此将 3 替换成该行的 5。以此类推,6 换成 7,9 还是 9。
Test 4
>>> index = torch.tensor([[0, 2],
[1, 2]])
>>> tensor_1 = tensor_0.gather(1, index)
# tensor_0
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
# tensor_1
tensor([[3, 5],
[7, 8]])
首先根据 index 的 shape = (2, 2),我们将这个水平长条对应到 tensor_0 的 [[3, 4], [6, 7]] 上。gather() 函数中第一个维度参数为 1,表明沿着水平方向根据 index 替换数值,例如 [[3, 4], [6, 7]] 中的 4,由于其对应位置 index 为 [[0, 2], [1, 2]] 中的 2,因此将 4 替换成该行的 5。以此类推,同一行的 3 还是 3,第二行的 6 换成7,7 换成 8。