pytorch 笔记:index_select

1 基本使用方法

index_select 是 PyTorch 中的一个非常有用的函数,允许从给定的维度中选择指定索引的张量值

torch.index_select(input, dim, index, out=None) -> Tensor
input 从中选择数据的源张量
dim 从中选择数据的维度
index

一个 1D 张量,包含你想要从 dim 维度中选择的索引

此张量应该是 LongTensor 类型

out

一个可选的参数,用于指定输出张量。

如果没有提供,将创建一个新的张量。

2 举例

import torch
import numpy as np

x = torch.tensor(np.arange(16).reshape(4,4))
index=torch.LongTensor([1,3])
x
'''
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]], dtype=torch.int32)
'''

torch.index_select(x,dim=0,index=index)
'''
tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]], dtype=torch.int32)
'''

torch.index_select(x,dim=1,index=index)
'''
tensor([[ 1,  3],
        [ 5,  7],
        [ 9, 11],
        [13, 15]], dtype=torch.int32)
'''

3 index_select保存梯度

import torch
import numpy as np

x = torch.tensor(np.arange(16).reshape(4,4),dtype=torch.float32, requires_grad=True)
index=torch.LongTensor([1,3])
x
'''
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.]], requires_grad=True)
'''

torch.index_select(x,dim=0,index=index)
'''
tensor([[ 4.,  5.,  6.,  7.],
        [12., 13., 14., 15.]], grad_fn=)
'''

torch.index_select(x,dim=1,index=index)
'''
tensor([[ 1.,  3.],
        [ 5.,  7.],
        [ 9., 11.],
        [13., 15.]], grad_fn=)
'''

你可能感兴趣的:(pytorch学习,笔记)