pytorch中gather函数的理解

pytorch函数gather理解

torch.gather(input, dim, index, out=None) → Tensor 

Parameters:

  • input (Tensor) – 源张量
  • dim (int) – 索引的轴
  • index (LongTensor) – 聚合元素的下标(index需要是torch.longTensor类型)
  • out (Tensor, optional) – 目标张量

公式含义

这个函数的意义就是可以重新排列特定维度的信息。对一个三维张量,从公式来看,输出是下面这种,就是在特定维度上,用索引index下标代替所在位置的值。

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

直观理解

原始tensor ,名称为a

a = torch.randint(0, 30, (2, 3, 5))

以下以 CxHxW的维度讲述,其中C=2,H=3, W=5,
pytorch中gather函数的理解_第1张图片

index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                        [[1,2,2,2,2],
                         [0,0,0,0,0],
                         [2,2,2,2,2]]])

指定dim = 1,也就是在第二个维度上H重新排列,

b = torch.gather(a, 1,index)

此时,第一个维度C是不会改变的,还是存在两个通道C,分别是a[0]和a[1],
针对a[0]或者a[1] , 在高度维度H上,分别是3行,a[0][0:2] a[1][0:2]。即

a[0].shape == [3,5]

因此,如果选择dim=1,则index 张量里面的数必须在0-2之间,不然会越界,
下一步就是选取数字了。
针对每一个通道C,输出张量b,只需要按照index重新排列矩阵即可
例如在第b[0,1,2]的位置,则选择a[0][index[0,1,2]][2]的值进行代替即可。

同理在其他维度也是一样。

注意点

需要注意的是索引矩阵不能越界,例如针对上述a[2,3,5],
如果指定dim=0,则index里面的数不能超过1,指定dim=1,则index不能超过2,指定dim=3,则index不能超过4

本文参考https://www.jianshu.com/p/5d1f8cd5fe31

你可能感兴趣的:(python,pytorch,人工智能,python)