Pytorch的使用:torch.gather函数

Pytorch的使用:torch.gather函数

  • **torch.gather()**
  • 作用:方便从批量tensor中获取特定化维度指定索引下的数据,该索引往往是乱序的。
  • 首先看一下官方文档中的3维数据
    • index 代表输入向量
    • dim 代表替换的维度
    • input 代表最终选取的元素
  • 接下来我们用一个二维的数据,分别采用官方文档的思路和我个人理解的思路进行简化举例
      • 我们采用了一个3×3的二维矩阵进行练习
      • 输出结果
    • 1.index为行向量,dim = 0 替换行索引
      • 计算思路1:
      • 计算思路2:
      • 输出结果
    • 2.index为列向量,dim = 0 替换行索引
      • 计算思路1:
      • 计算思路2:
      • 输出结果
    • 3.index为行向量,dim = 1 替换列索引
      • 计算思路1:
      • 计算思路2:
      • 输出结果
    • 参考文献

torch.gather()

作用:方便从批量tensor中获取特定化维度指定索引下的数据,该索引往往是乱序的。

首先看一下官方文档中的3维数据

index 代表输入向量

dim 代表替换的维度

input 代表最终选取的元素

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

接下来我们用一个二维的数据,分别采用官方文档的思路和我个人理解的思路进行简化举例

我们采用了一个3×3的二维矩阵进行练习

import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

输出结果

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

1.index为行向量,dim = 0 替换行索引

index = torch.tensor([[1, 2, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

计算思路1:

dim = 0,所以替换行索引,即input[ index[i][j] ][j],可见整个过程就是将行替换,分别是[1,2,0],而列即为index的列,不发生变化,也为[2,1,0],即取出[(1,0),(2,1),(0,2)]。

计算思路2:

当我们熟悉了计算之后,就可以找到其中的逻辑,而不需要每次都带入计算索引。
dim = 0 代表替换行索引,而输入的是行向量,那我们先将列索引写出,即[0,1,2],然后将行索引替换为index,即[1,2,0],合并后就是最终索引[(1,0),(2,1),(0,2)]。

输出结果

tensor([[4, 8, 3]])

2.index为列向量,dim = 0 替换行索引

index = torch.tensor([[1, 2, 0]]).t()
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)

计算思路1:

dim = 0,所以替换行索引,即input[ index[i][j] ][j],可见整个过程就是将行替换,分别是[1,2,0],而列即为index的列,即为[0,0,0],合并即取出[(1,0),(2,0),(0,0)]。

计算思路2:

dim = 0 代表替换行索引,而输入的是列向量,因为只有1列,所以列索引即[0,0,0],然后将行索引替换为index,即[1,2,0],合并后索引为[(1,0),(2,0),(0,0)]。

输出结果

tensor([[4],
        [7],
        [1]])

3.index为行向量,dim = 1 替换列索引

index = torch.tensor([[1, 2, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)

计算思路1:

dim = 1,所以替换列索引,即input[i][ index[i][j] ],可见整个过程就是将列替换,分别是[1,2,0],而行即为index的行,即为[0,0,0],合并即取出[(0,1),(0,2),(0,0)]。

计算思路2:

dim = 0 代表替换行索引,而输入的是列向量,因为只有1列,所以行索引即[0,0,0],然后将列索引替换为index,即[1,2,0],合并后索引为[(0,1),(0,2),(0,0)]。

输出结果

tensor([[2, 3, 1]])

是不是很简单,相信你已经理解了。

参考文献

https://zhuanlan.zhihu.com/p/352877584 图解PyTorch中的torch.gather函数

你可能感兴趣的:(pytorch,pytorch,深度学习,python)