torch.gather() 函数理解

本文主要参考 https://zhuanlan.zhihu.com/p/352877584 ,在对 torch.gather() 理解之后,总结了一套比较好用的计算方法,下面直接来看例子。

>>> import torch
>>> tensor_0 = torch.arange(3, 12).view(3, 3)

Test 1

>>> index = torch.tensor([[2, 1, 0]])
>>> tensor_1 = tensor_0.gather(0, index)

# tensor_0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
# tensor_1 
tensor([[9, 7, 5]])

首先根据 index 的 shape = (1, 3),我们将这个水平长条对应到 tensor_0 的 [[3, 4, 5]] 上(尽量往左上角放)。gather() 函数中第一个维度参数为 0,表明沿着竖直方向根据 index 替换数值,例如 [[3, 4, 5]] 中的 3,由于其对应位置 index 为 [[2, 1, 0]] 中的 2,因此将 3 替换成该的 9。以此类推,4 换成 7,5 还是 5。

Test 2

>>> index = torch.tensor([[2, 1, 0]])
>>> tensor_1 = tensor_0.gather(1, index)

# tensor_0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
# tensor_1 
tensor([[5, 4, 3]])

首先根据 index 的 shape = (1, 3),我们将这个水平长条对应到 tensor_0 的 [[3, 4, 5]] 上。gather() 函数中第一个维度参数为 1,表明沿着水平方向根据 index 替换数值,例如 [[3, 4, 5]] 中的 3,由于其对应位置 index 为 [[2, 1, 0]] 中的 2,因此将 3 替换成该的 5。以此类推,4 还是 4,5 换成 3。

Test 3

>>> index = torch.tensor([[2, 1, 0]]).t()  # 表示一列 [[2], [1], [0]]
>>> tensor_1 = tensor_0.gather(1, index)

# tensor_0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
# tensor_1
tensor([[5],
        [7],
        [9]])

首先根据 index 的 shape = (3, 1),我们将这个竖直长条对应到 tensor_0 的 [[3], [6], [9]] 上。gather() 函数中第一个维度参数为 1,表明沿着水平方向根据 index 替换数值,例如 [[3], [6], [9]] 中的 3,由于其对应位置 index 为 [[2], [1], [0]] 中的 2,因此将 3 替换成该的 5。以此类推,6 换成 7,9 还是 9。

Test 4

>>> index = torch.tensor([[0, 2], 
                          [1, 2]])
>>> tensor_1 = tensor_0.gather(1, index)

# tensor_0
tensor([[ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
# tensor_1
tensor([[3, 5],
        [7, 8]])

首先根据 index 的 shape = (2, 2),我们将这个水平长条对应到 tensor_0 的 [[3, 4], [6, 7]] 上。gather() 函数中第一个维度参数为 1,表明沿着水平方向根据 index 替换数值,例如 [[3, 4], [6, 7]] 中的 4,由于其对应位置 index 为 [[0, 2], [1, 2]] 中的 2,因此将 4 替换成该的 5。以此类推,同一行的 3 还是 3,第二行的 6 换成7,7 换成 8。

你可能感兴趣的:(pytorch,python,经验分享)