PyTorch 笔记(06)— Tensor 索引操作(index_select、masked_select、non_zero、gather)

Tensor 支持与 numpy.ndarray 类似的索引操作,如无特殊说明,索引出来的结果与源 tensor 共享内存,即修改一个,另外一个也会跟着改变。

In [65]: a = t.arange(0,6).reshape(2,3)                                                                                                                              

In [66]: a                                                                                                                                                           
Out[66]: 
tensor([[0, 1, 2],
        [3, 4, 5]])

1. 初级索引

1. 获取第 0 行

In [67]: a[0]          # 第 0 行                                                                                                                                                
Out[67]: tensor([0, 1, 2])

2. 获取第 0 列

In [68]: a[:,0]       # 第 0 列                                                                                                                                                     
Out[68]: tensor([0, 3])

3. 获取第 0 行某个元素

In [69]: a[0][2]         # 第 0 行 第 2 个元素                                                                                                                                    
Out[69]: tensor(2)

In [70]: a[0,2]   # 等价    a[0][2]                                                                                                                                                
Out[70]: tensor(2)

In [71]: a[0, -1]     # 第 0 行 最后一个元素                                                                                                                                       
Out[71]: tensor(2)

4. 获取前 1 行

In [72]: a[:1]      # 前 1 行                                                                                                                                                 
Out[72]: tensor([[0, 1, 2]])

In [73]: a[0:1, 0:2]    # 第 0 行第 0 列 和第 0 行第 1 列                                                                                                                    
Out[73]: tensor([[0, 1]])

In [74]: a[0:2, 1:2]     # 第 0 行第 1 列 和第 1 行第 1 列                                                                                                                       
Out[74]: 
tensor([[1],
        [4]])

In [75]: a[0:2, 0:2]                                                                                                                                                 
Out[75]: 
tensor([[0, 1],
        [3, 4]])

In [76]:      

2. 高级索引

常用选择函数如下表所示:
PyTorch 笔记(06)— Tensor 索引操作(index_select、masked_select、non_zero、gather)_第1张图片

2.1 index_select

index_select(input, dim, index)

  • input 表示输入的变量;
  • dim 表示从第几维挑选数据,类型为 int 值;
  • index 表示从选择维度中的哪个位置挑选数据,类型为 torch.Tensor 类的实例;

t.index_select(a, 0, t.tensor([0, 1])) 表示挑选第 0 维,t.tensor([0, 1]) 表示第 0 行、第 1 行

t.index_select(a, 1, t.tensor([1, 3])) 表示挑选第 1 维,t.tensor([1, 3]) 表示第 1 行、第 3 行(第一行从 0 开始计数)

In [9]: a = t.arange(0, 12).reshape(3,4)                                                                                                                             

In [10]: a                                                                                                                                                           
Out[10]: 
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])

In [11]: b = t.index_select(a, 0, t.tensor([0, 1]))                                                                                                                  

In [12]: b                                                                                                                                                           
Out[12]: 
tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

In [13]: a.index_select(0, t.tensor([0, 1]))                                                                                                                         
Out[13]: 
tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

In [17]: a.index_select(1, t.tensor([1,3]))                                                                                                                          
Out[17]: 
tensor([[ 1,  3],
        [ 5,  7],
        [ 9, 11]])

In [18]: c = t.index_select(a, 1, t.tensor([1, 3]))                                                                                                                  

In [19]: c                                                                                                                                                           
Out[19]: 
tensor([[ 1,  3],
        [ 5,  7],
        [ 9, 11]])

In [20]: 

2.2 masked_select

torch.masked_select(input, mask, out=None)

根据掩码张量 mask 中的二元值,取输入张量中的指定项,将取值返回到一个新的 1D 张量。张量 mask 须跟 input 张量有相同的元素数目,但形状或维度不需要相同。返回的张量不与原始张量共享内存空间。

  • input(Tensor) 输入张量;
  • mask(ByteTensor) 掩码张量,包含了二元索引值;
  • out 目标张量;
In [1]: import torch as t

In [2]: a = t.arange(0, 6).reshape(2, 3)

In [76]: a                                                                                                                                                           
Out[76]: 
tensor([[0, 1, 2],
        [3, 4, 5]])

In [77]: a > 2                                                                                                                                                       
Out[77]: 
tensor([[False, False, False],
        [ True,  True,  True]])

In [78]: a[a>2]          # 选择结果与源 Tensor 不共享内存空间                            
Out[78]: tensor([3, 4, 5])

In [79]: a.masked_select(a>2)      # 等价于  a[a>2]                                                                                                                                
Out[79]: tensor([3, 4, 5])

In [80]: a[a>2][0] = 100                                                                                                                                             

In [81]: a                                                                                                                                                           
Out[81]: 
tensor([[0, 1, 2],
        [3, 4, 5]])

In [82]: a[a>2]                                                                                                                                                      
Out[82]: tensor([3, 4, 5])

In [83]: 

2.3 non_zero

non_zero 返回一个包含输入 input 中非零元素索引的张量。输出张量中的每行包含 input 中非零元素的索引。

如果输入 inputn 维,则输出的索引张量 outsizez x n , 这里 z 是输入张量 input 中所有非零元素的个数。

In [2]: a = t.arange(0, 6).reshape(2, 3)

In [3]: a
Out[3]: 
tensor([[0, 1, 2],
        [3, 4, 5]])

In [4]: type(a>2)
Out[4]: torch.Tensor

In [5]: a.nonzero()
Out[5]: 
tensor([[0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 2]])

In [7]: t.nonzero(a!=0)
Out[7]: 
tensor([[0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 2]])

In [8]: 

2.4 gather

收集输入的特定维度指定位置的数值。

torch.gather(input, dim, index, out=None) → Tensor
  • input (Tensor) – 源张量,也就是输入的待处理变量;
  • dim (int) – 索引的轴,待操作的维度;
  • index (LongTensor) – 聚合元素的下标
  • out (Tensor, optional) – 目标张量
In [1]: import torch as t

In [8]: a = t.arange(0, 6).reshape(2, 3)

In [9]: a
Out[9]: 
tensor([[0, 1, 2],
        [3, 4, 5]])

In [4]: a.sum(dim=0)
Out[4]: tensor([3, 5, 7])

In [5]: a.sum(dim=1)
Out[5]: tensor([ 3, 12])


In [12]: t.gather(a, 0, t.LongTensor([[0,1,0], [1,0,0]]))
Out[12]: 
tensor([[0, 4, 2],
        [3, 1, 2]])

In [13]: t.gather(a, 1, t.LongTensor([[2,0,1], [1,2,0]]))
Out[13]: 
tensor([[2, 0, 1],
        [4, 5, 3]])

In [14]: 

a.sum(dim=0) 可知当 dim=0 时,是按照列的方向求和的,所以求 t.gather(a, 0, t.LongTensor([[0,1,0], [1,0,0]])) 值时可以按照以下步骤进行:

  1. 取各个元素的列下标,如 [(x,0), (x,1), (x,2)], [(x,0), (x,1), (x,2)]
  2. t.LongTensor([[0,1,0], [1,0,0]]) 值作为行下标, 如 [(0,0), (1,1), (0,2)], [(1,0), (0,1), (0,2)]
  3. 根据步骤 2 得到的索引在 input 中求值,即
    a[0][0] = 0a[1][1] = 4a[0][2] = 2
    a[1][0] = 3a[0][1] = 1a[0][2] = 2
    得到如下值
tensor([[0, 4, 2],
        [3, 1, 2]])

同理,对于 a.sum(dim=1) 可知当 dim=1 时,是按照行的方向求和的,所以求 t.LongTensor([[2,0,1], [1,2,0]]) 值时可以按照以下步骤进行:

  1. 取各个元素的行下标,如 [(0,x), (0,x), (0,x)], [(1,x), (1,x), (1,x)]
  2. t.LongTensor([[2,0,1], [1,2,0]]) 值作为列下标, 如 [(0,2), (0,0), (0,1)], [(1,1), (1,2), (1,0)]
  3. 根据步骤 2 得到的索引在 input 中求值,即
    a[0][2] = 2a[0][0] = 0a[0][1] = 1
    a[1][1] = 4a[1][2] = 5a[1][0] = 3
    得到如下值
tensor([[2, 0, 1],
        [4, 5, 3]])

你可能感兴趣的:(PyTorch)