在看动手学深度学习 pytorch版本的时候,看到其中使用了index_select方法。感觉这个方法较为常用和有用。所以需要弄懂。
用法如下:
torch.index_select(input, dim, index, out=None) → Tensor
- 参数input(Tensor):表示被选择的tensor
- 参数dim(int):表示是在哪个维度做选择。为1表示,在列上做选择。为0表示在行上做选择。
- 参数index(LongTensor):表示需要选择出的行或者列。 它是一个数组为Long形的tensor。(举一个列子,如果被选择的tensor是一个二维的。index是一维的[1,4,8],dim=0,那么就是选出第2行,第5行,第9行,下标从0开始。当然输入的tensor可能是多维的)
- 参数out(Tensor, optional):表示的输出到到哪个tensor上去,该参数可选。
pytorch官网上的列子:
动手学深度学习(Dive-into-DL-PyTorch) 上写法的列子的列子:
官网讲解该部分地址:https://pytorch.org/docs/stable/torch.html?highlight=index#torch.index_select