Pytorch中tensor的索引

Pytorch中tensor的索引

    • 引言
    • 分类
    • 概述
    • 例子
      • 1、0维的索引
      • 2、1维的索引
      • 3、2维的索引
    • 参考

引言

做了一段时间的目标检测,在这个过程中也复现了不少经典的检测网络,例如,faster-rcnn,yolov3,retinanet等等。在学习,消化别人代码的过程中,经常会遇到一个多维的tensor来索引另一个多维的tensor这种种类似的情况,而我对索引的概念还停留在a[0],a[0,:],a[0,…]的阶段,因此在理解代码的过程中需要多次调试,来查看tensor的shape,dytpe。最近经过多方面的调研,总结,特写下以下心得。

分类

tensor的索引共分为两种情况,一是整型(int)的索引,另一种是布尔型(bool)索引。此处注意,索引也是一个tensor。
int型索引中,索引的维度可以是0维,1维,2维…,常使用的a[1,:]中,1是被看成一个0维tensor。而bool型索引的维度是根据被索引的tensor决定的(以下用Tensor表示被索引的tensor)。

概述

可进行索引的位置是由Tensor的维度决定的,例如Tensor的维度为3,则有3个不同的索引位置,这个相信大家都能理解。当索引位置出现int索引或者bool型索引时,其中的区别在于:bool型索引考虑这个维度中的数保留与否,而int型考虑的是这个维度中的数保留哪个。下面通过例子进行说明。

例子

1、0维的索引

int型

 Tensor = torch.randn(5, 7, 3)
 idx = torch.tensor(4,dtype=torch.long)
 a = Tensor[idx]

则a的shape为(7,3)

bool型不存在0维索引

2、1维的索引

int型

Tensor = torch.randn(5, 7, 3)
idx = torch.tensor([0,4,3,1],dtype=torch.long)
a = Tensor[idx]

则a的shape为(4,7,3)

int型多重1维的索引

Tensor = torch.randn(5, 7, 3)
idx1 = torch.tensor([0, 4, 2],dtype=torch.long)
idx2 = torch.tensor([5, 2, 6],dtype=torch.long)
a = Tensor[idx1,idx2]

则a的shape为(3,3)

bool型

Tensor = torch.randn(5, 7, 3)
idx = torch.tensor([0, 0, 1, 1, 0], dtype=torch.bool)
a = Tensor[idx]

则a的shape为(2,7,3)

3、2维的索引

int型

Tensor = torch.randn(5, 7, 3)
idx = torch.tensor([[0, 1], [4, 3]],dytpe=torch.long)
a = Tensor[idx]

则a的shape为(2,2,7,3)

int型多重2维的索引

Tensor = torch.arange(0, 12).view(4, 3)
rows_idx = torch.tensor([[0, 0], [3, 3]],dtype=torch.long)
columns_idx = torch.tensor([[0, 2], [0, 2]],dtype=torch.long)
a = Tensor[rows_idx,columns_idx]

则a的shape为(2,2,3)

bool型

Tensor = torch.randn(5, 7, 3)
idx = torch.randn(5, 7)
num_one= (idx > 0).sum()
a = Tensor[idx>0]

则a的shape为(num_one,3)

参考

[1]https://discuss.pytorch.org/t/slice-tensor-using-boolean-tensor/7324
[2]https://github.com/pytorch/pytorch/issues/2405
[3]http://christopher5106.github.io/torch/2019/06/23/torch-uint8-is-numpy_bool.html

你可能感兴趣的:(pytorch,学习)