torch.gather 函数用于从输入张量中收集(或选择)指定位置的元素,然后将它们放入一个新的张量中。这对于根据索引从输入张量中检索值非常有用。torch.gather 的用法如下:
torch.gather(input, dim, index, out=None)
1.input:输入张量,从中收集数据。
2.dim:要在哪个维度上收集数据。这是一个整数,指定了应该沿哪个维度进行选择。例如,如果 dim=1,则表示沿着输入张量的第二个维度进行选择。
3.index:用于指定要收集的元素的索引的张量。这个索引张量的形状必须与 input 在 dim 维度上广播兼容。
4.out:可选参数,用于指定结果存储的张量。
import torch
# 创建一个输入张量
input = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 创建一个索引张量,指定要收集的元素的位置
index = torch.tensor([[0, 2],
[1, 0],
[2, 1]])
# 使用 torch.gather 收集指定位置的元素
output1 = torch.gather(input, dim=0, index=index)
output2 = torch.gather(input, dim=1, index=index)
print(output)
输出
tensor([[1, 8],
[4, 2],
[7, 5]]) tensor([[1, 3],
[5, 4],
[9, 8]])
首先这是个索引函数,最后索引出来的形状
这是dim=0的结果的元素在input里面的位置
[0,0] [2,1]
[1,0] [0,1]
[2,0] [1,1]
这是dim=1的结果的元素在input里面的位置
[0,0] [0,2]
[1,1] [1,0]
[2,2] [2,1]
这个dim就是索引的维度,其他维度的值不变,只改变这个维度的索引值,最后的形状与index形状一致,所以index的形状要小于input的形状。
但是我觉得从维度理解还是很难理解,我觉得我们最方便理解的其实是dim=1,就是一行一行地看,看这行选哪个。