index_select
是 PyTorch 中的一个非常有用的函数,允许从给定的维度中选择指定索引的张量值
torch.index_select(input, dim, index, out=None) -> Tensor
input | 从中选择数据的源张量 |
dim | 从中选择数据的维度 |
index | 一个 1D 张量,包含你想要从 此张量应该是 |
out | 一个可选的参数,用于指定输出张量。 如果没有提供,将创建一个新的张量。 |
import torch
import numpy as np
x = torch.tensor(np.arange(16).reshape(4,4))
index=torch.LongTensor([1,3])
x
'''
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]], dtype=torch.int32)
'''
torch.index_select(x,dim=0,index=index)
'''
tensor([[ 4, 5, 6, 7],
[12, 13, 14, 15]], dtype=torch.int32)
'''
torch.index_select(x,dim=1,index=index)
'''
tensor([[ 1, 3],
[ 5, 7],
[ 9, 11],
[13, 15]], dtype=torch.int32)
'''
import torch
import numpy as np
x = torch.tensor(np.arange(16).reshape(4,4),dtype=torch.float32, requires_grad=True)
index=torch.LongTensor([1,3])
x
'''
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.]], requires_grad=True)
'''
torch.index_select(x,dim=0,index=index)
'''
tensor([[ 4., 5., 6., 7.],
[12., 13., 14., 15.]], grad_fn=)
'''
torch.index_select(x,dim=1,index=index)
'''
tensor([[ 1., 3.],
[ 5., 7.],
[ 9., 11.],
[13., 15.]], grad_fn=)
'''