pytorch中index_select函数的作用选取某一维度上的数据
函数形式为:
index_select(input, dim, index)
input为tensor,dim是维度从0开始,index是一维tensor(向量),表示在这个维度上要选择的下标
下边直接上例子:
import torch
x = torch.Tensor([[[1, 2, 3],
[4, 5, 6]],
[[9, 8, 7],
[6, 5, 4]]])
print(x)
print(x.size())
index = torch.LongTensor([0, 0, 1])
print(torch.index_select(x, 0, index))
print(torch.index_select(x, 0, index).size())
print(torch.index_select(x, 1, index))
print(torch.index_select(x, 1, index).size())
print(torch.index_select(x, 2, index))
print(torch.index_select(x, 2, index).size())
input的张量形状为2×2×3,index为[0, 0, 1]的向量
分别从0、1、2三个维度来使用index_select()函数,并输出结果和形状,维度大于2就会报错因为input最大只有三个维度
输出:
tensor([[[1., 2., 3.],
[4., 5., 6.]],
[[9., 8., 7.],
[6., 5., 4.]]])
torch.Size([2, 2, 3])
tensor([[[1., 2., 3.],
[4., 5., 6.]],
[[1., 2., 3.],
[4., 5., 6.]],
[[9., 8., 7.],
[6., 5., 4.]]])
torch.Size([3, 2, 3])
tensor([[[1., 2., 3.],
[1., 2., 3.],
[4., 5., 6.]],
[[9., 8., 7.],
[9., 8., 7.],
[6., 5., 4.]]])
torch.Size([2, 3, 3])
tensor([[[1., 1., 2.],
[4., 4., 5.]],
[[9., 9., 8.],
[6., 6., 5.]]])
torch.Size([2, 2, 3])
对结果进行分析:
index是大小为3的向量,输入的张量形状为2×2×3
dim = 0时,输出的张量形状为3×2×3
dim = 1时,输出的张量形状为2×3×3
dim = 2时,输出的张量形状为2×2×3
注意输出张量维度的变化与index大小的关系,结合输出的张量与原始张量来分析index_select()函数的作用