torch.gather(...)

1. Abstract

对于 pytorch 中的函数

torch.gather(
	input,  # (Tensor) the source tensor
	dim,    # (int)    the axis along which to index
	index,  # (LongTensor) the indices of elements to gather
	*,
	sparse_grad=False,
	out=None
) → Tensor

有点绕,很多博客画各种图讲各种故事来解释如何input 张量中 gather 位置 index 处的值,乱七八糟,我是都没看明白。所以去官网看了文档:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

从这三行看,意思还是很明晰的:输出 out 和输入 input 之间的差别就是,把相应位置(dim)的下标替换成 index[i][j][k]dim=0,1,2 分别对应替换的位置0,1,2。但这不够直观!

【注】从上面三行代码可以看出,indexinput 的维度尺寸是一样的,即 len(index.shape) == len(input.shape),但不一定是相同的形状:index.shape[dim] ≠ input.shape[dim](其他维度的形状必须满足 index.shape <= input.shape)。

2. 图解

2.1 一维向量

先从简单的一维向量看看,假设:

x = torch.tensor([3, 4, 5, 6, 7])
index = torch.tensor([4, 4, 1, 1, 0, 3])
out = torch.gather(x, dim=0, index=index)

按规则看:

out[i] = input[index[i]]  # dim == 0

即,从向量里选取指定位置 index[i] 处的数字,放到输出向量 out[i] 处。这个很好理解,pythonnumpypytorch 都有这样的语法:

y = x[index]
print(y)
### output ###
tensor([7, 7, 4, 4, 3, 6])

大家应该很熟悉吧!现在,用图的方式展现这个过程:
torch.gather(...)_第1张图片
torch.gather(...) 函数,就是这样的:

out = torch.gather(x, dim=0, index=index)
print(out)
### output ###
tensor([7, 7, 4, 4, 3, 6])

举例来说,上面的 index[4] = 0,那么它会寻找 input[index[4]] = input[0] = 3,然后放入 out[4]。这就是英文单词 gather 的意思。

index 的长度是不受限制的,即 gather 多少元素都可以。

小结:在一维向量下,out = torch.gather(x, dim=0, index=index) 等价于 out = x[index]

2.2 二维矩阵

往上升一个维度,看看对二维矩阵实施 gather 函数的操作:

x = torch.tensor([[3, 4, 5, 6, 7], [9, 8, 7, 6, 5]])
idx = torch.randint(low=0, high=5, size=(2, 6))
y = torch.gather(x, dim=1, index=idx)
print(x)
print(idx)
print(y)
### output ###
tensor([[3, 4, 5, 6, 7],
        [9, 8, 7, 6, 5]])
tensor([[4, 4, 1, 1, 0, 3],
        [0, 1, 2, 1, 4, 1]])
tensor([[7, 7, 4, 4, 3, 6],
        [9, 8, 7, 8, 5, 8]])

按规则看:

out[i][j] = input[i][index[i][j]]  # dim == 1`

即,从向量 input[i] 里选取指定位置 index[i][j] 处的数字,放到输出向量 out[i][j] 处。也许多了一个维度就有点绕了,但仔细观察,我们可以假定 i = 0,此时:

out[0][j] = input[0][index[0][j]]  # 对应下图的左侧

若假定 i = 1,则:

out[1][j] = input[1][index[1][j]]  # 对应下图的右侧

a = input[i]b = index[i]c = out[i],我们就回到了一维向量gather 操作:

c[j] = a[b[j]]

即,输出 c = out[i] 是对输入 a = input[i] 执行了一次与一维向量时一样的 gather 操作,其中下标是 b = index[i]。在二维矩阵上的 gather 操作,不过是并行地执行了多个一维向量的 gather

示意图:
torch.gather(...)_第2张图片
【图注】input 矩阵中,稍深蓝色的行表示当前被 gather 的行,左图 gather 第一行,右图 gather 第二行。

上面是 dim = 1 时的情况,是沿着矩阵的进行 gather,当 dim = 0 时,就是沿着进行 gather

out[i][0] = input[index[i][0]][0]  # dim == 0
out[i][1] = input[index[i][1]][1]
...

torch.gather(...)_第3张图片
【图注】input 的稍深蓝色的列是正在被 gather 的列。

也就是并行地执行多个列向量gather,每列 index 是一个并行分支,并行分支的数量可以小于 input 的列数,但不能超过,超过的话,它 gather 哪一列呢?

小结:二维矩阵的 gather 操作就是并行地执行了多个一维向量的 gather 操作;dim=1 按行 gatherdim=0 按列 gather

2.3 高维张量

弄懂一维到二维的 gather,更高维的操作也就清晰了,就是画图有一点难画。假设

x = tensor([[[ 0,  1,  2,  3,  4],
             [ 5,  6,  7,  8,  9]],

		    [[10, 11, 12, 13, 14],
             [15, 16, 17, 18, 19]],

		    [[20, 21, 22, 13, 24],
             [25, 26, 27, 28, 29]]])

则当 dim == 0 时,是沿着第一维进行 gather 的,那么 index.shape[0] (一个并行分支 gather 的元素的数量) 可为任意数,这里设置为 4,其他 index.shape[i≠0] <= input.shape[i≠0]

index = tensor([[[1, 2, 2],
         		 [2, 2, 0]],

       			[[0, 0, 1],
         		 [1, 0, 1]],

		        [[2, 0, 0],
        		 [0, 1, 2]],

		        [[1, 1, 0],
        		 [0, 0, 0]]])

index.shape == (4, 2, 3),执行:

y = torch.gather(x, dim=0, index=index)

的示意图如下:

只画了看得见的前两列(两个并行 gather 分支)。红色和绿色箭头表示两列下标沿着 dim=0 进行 gather 操作,每一列和一维向量的 gather 是一样的,只不过这里有 2*3 个列。

再往高维拓展,也是一样,都是从基本的一维向量 gather 拓到并行 gather

你可能感兴趣的:(python,pytorch,python,pytorch)