torch.tensor.gather用法

torch.tensor.gather用法

  • 介绍
    • 示例代码1
    • 示例代码2

介绍

官网介绍:https://pytorch.org/docs/stable/generated/torch.Tensor.gather.html
torch.Tensor.gather是PyTorch中的一个函数,它根据索引从输入张量中收集值。

示例代码1

以下是一个使用torch.Tensor.gather的示例:

import torch

# 创建一个输入张量
input = torch.tensor([[1, 2], [3, 4]])

# 创建一个索引张量
index = torch.tensor([[0, 0], [1, 0]])

# 使用gather函数
output = input.gather(1, index)
print(output)
# tensor([[1, 1],
#         [4, 3]])

在上述代码中,input.gather(1, index)会沿着维度1(列)收集值。索引张量index中的每个值指定了在相应位置收集哪个元素。
input[1, 2]的索引为[0, 1]index[0, 0]用两个0索引拿到的值为[1, 1];同理,input[3, 4]的索引为[0, 1]index[1, 0]用两个0索引拿到的值为[4, 3];因此,output的值为[[1, 1], [4, 3]]

示例代码2

以下是一个使用torch.Tensor.gather的3维tensor示例:

import torch

# 创建一个输入张量
input = torch.arange(0,8).view(2, 2, 2)
# tensor([[[0, 1],
#          [2, 3]],
#         [[4, 5],
#          [6, 7]]])

# 创建一个索引张量
index = torch.tensor([[[0,0]],[[1,0]]])

# 使用gather函数
output = input.gather(1, index)
print(output)
# tensor([[[0, 1]],
#         [[6, 5]]])

# 错误示例
index = torch.tensor([[[0,0]],[[10,0]]])
input.gather(1, index) # RuntimeError: index 10 is out of bounds for dimension 1 with size 2

你可能感兴趣的:(Pytorch,pytorch)