out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
index = torch.tensor([[1, 2, 0]])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
dim = 0,所以替换行索引,即input[ index[i][j] ][j],可见整个过程就是将行替换,分别是[1,2,0],而列即为index的列,不发生变化,也为[2,1,0],即取出[(1,0),(2,1),(0,2)]。
当我们熟悉了计算之后,就可以找到其中的逻辑,而不需要每次都带入计算索引。
dim = 0 代表替换行索引,而输入的是行向量,那我们先将列索引写出,即[0,1,2],然后将行索引替换为index,即[1,2,0],合并后就是最终索引[(1,0),(2,1),(0,2)]。
tensor([[4, 8, 3]])
index = torch.tensor([[1, 2, 0]]).t()
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
dim = 0,所以替换行索引,即input[ index[i][j] ][j],可见整个过程就是将行替换,分别是[1,2,0],而列即为index的列,即为[0,0,0],合并即取出[(1,0),(2,0),(0,0)]。
dim = 0 代表替换行索引,而输入的是列向量,因为只有1列,所以列索引即[0,0,0],然后将行索引替换为index,即[1,2,0],合并后索引为[(1,0),(2,0),(0,0)]。
tensor([[4],
[7],
[1]])
index = torch.tensor([[1, 2, 0]])
tensor_1 = tensor_0.gather(1, index)
print(tensor_1)
dim = 1,所以替换列索引,即input[i][ index[i][j] ],可见整个过程就是将列替换,分别是[1,2,0],而行即为index的行,即为[0,0,0],合并即取出[(0,1),(0,2),(0,0)]。
dim = 0 代表替换行索引,而输入的是列向量,因为只有1列,所以行索引即[0,0,0],然后将列索引替换为index,即[1,2,0],合并后索引为[(0,1),(0,2),(0,0)]。
tensor([[2, 3, 1]])
是不是很简单,相信你已经理解了。
https://zhuanlan.zhihu.com/p/352877584 图解PyTorch中的torch.gather函数