torch.gather函数用法

torch.gather 函数用于从输入张量中收集(或选择)指定位置的元素,然后将它们放入一个新的张量中。这对于根据索引从输入张量中检索值非常有用。torch.gather 的用法如下:
torch.gather(input, dim, index, out=None)

参数说明:

1.input:输入张量,从中收集数据。
2.dim:要在哪个维度上收集数据。这是一个整数,指定了应该沿哪个维度进行选择。例如,如果 dim=1,则表示沿着输入张量的第二个维度进行选择。
3.index:用于指定要收集的元素的索引的张量。这个索引张量的形状必须与 input 在 dim 维度上广播兼容。
4.out:可选参数,用于指定结果存储的张量。

示例:

import torch

# 创建一个输入张量
input = torch.tensor([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])

# 创建一个索引张量,指定要收集的元素的位置
index = torch.tensor([[0, 2],
                      [1, 0],
                      [2, 1]])

# 使用 torch.gather 收集指定位置的元素
output1 = torch.gather(input, dim=0, index=index)
output2 = torch.gather(input, dim=1, index=index)
print(output)

输出

tensor([[1, 8],
        [4, 2],
        [7, 5]]) tensor([[1, 3],
        [5, 4],
        [9, 8]])

首先这是个索引函数,最后索引出来的形状

这是dim=0的结果的元素在input里面的位置
[0,0]  [2,1]
[1,0]  [0,1]
[2,0]  [1,1]
这是dim=1的结果的元素在input里面的位置
[0,0] [0,2]
[1,1] [1,0]
[2,2] [2,1]

这个dim就是索引的维度,其他维度的值不变,只改变这个维度的索引值,最后的形状与index形状一致,所以index的形状要小于input的形状。
但是我觉得从维度理解还是很难理解,我觉得我们最方便理解的其实是dim=1,就是一行一行地看,看这行选哪个。

你可能感兴趣的:(python,pytorch,深度学习,人工智能)