Pytorch:浅探Tensor的各种索引形式

文章目录

    • 前因
    • 各种Tensor索引操作
      • 1. 简单索引
      • 2. 一般的花式索引
      • 3. 复杂的花式索引
      • 4. Informer代码示例

前因

之前一直以为对ndarray的各种索引切片操作还算得上熟悉,但今天师弟问了我Informer实现中ProbSparse Self-Attention的一些Tensor索引操作,才发现有些操作还不太懂,而网上也缺乏相关的参考资料。因此在一系列探索下,写下了这篇博客。

各种Tensor索引操作

构造示例数组x,为一个三维tesnor:

import torch
x = torch.arange(16).reshape(2,2,4)
x
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]]])

1. 简单索引

最简单的是直接传入标量,分别取对应维度的对应数据即可。需要主要的是:表示取该维度全数据,x:y表示取该维度上[x,y)之间的数据。

示例1:

x[0, 0, 0], x[0, 0, :], x[0, 0, 0:2]
(tensor(0), tensor([0, 1, 2, 3]), tensor([0, 1]))

其次是在单维度上传入一维tensor数组,就是在对应的维度上依次获取到对应元素即可。

示例2:

x[1, 1, [0,2,1,0]]
tensor([12, 14, 13, 12])

2. 一般的花式索引

在多个维度上传入一维tensor数组,类似于numpy中的花式索引,对应的tensor数组提供索引关系。

示例3:

x[0, [0,1], [2,3]]

获取到的元素为x[0,0,2]x[0,1,3],即[0,1]提供dim 1的索引值,[2,3]提供dim 2的索引值。

tensor([2, 7])

示例4:

x[0, [0,1,0], [2]]

这种做法会利用到广播机制,实际上的操作会变成x[0, [0,1,0], [2,2,2]]

tensor([2, 6, 2])

但如果写成这样就会报错:

x[0, [0,1,0], [2,1]]
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-21-94117f6cd73a> in <module>
----> 1 x[0, [0,1,0], [2,1]]

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [2]

3. 复杂的花式索引

在2的基础上,传入的不止是一维的tensor数组,而是2维甚至多维的。

示例5:

x[0, [[0,1]], [[2,3]]]

dim1上的tensor数组d1[[0,1]],在dim2上的tensor数组d2[[2,3]],同样进行花式索引构成[x[0,0,2], x[0,1,3]]注意输出多了一层维度

去掉dim0来看,获取的元素位置就是[(0,2), (1,3)]

tensor([[2, 7]])

示例6:

x[0, [[0,1],[0,1]], [[0,1],[2,3]]]

dim1上的tensor数组d1[[0,1],[0,1]],在dim2上的tensor数组d2[[0,1],[2,3]],分别进行花式索引构成[[x[0,0,0], x[0,1,1]],[x[0,0,2], x[0,1,3]]]

去掉dim0来看,获取的元素位置就是[(0,0), (1,1)], [(0,2), (1,3)]

tensor([[0, 5],
        [2, 7]])

同样可以利用到类似于示例3的广播机制,如下所示:

x[0, [[0,1],[0,1]], [0,1]], x[0, [[0,1],[0,1]], [0]]
tensor([[0, 5],        tensor([[0, 4],
        [0, 5]])				[0, 4]])

以及更复杂的示例:

x[0, [[0],[1]], [[0,1],[2,3]]]

等效于x[0, [[0, 0], [1, 1]], [[0,1],[2,3]]]

tensor([[0, 1],
        [6, 7]])

4. Informer代码示例

在informer的_prob_QK函数中,有一段代码是为了从 Q Q Q中根据索引 M t o p M_{top} Mtop得到 Q r e d u c e Q_{reduce} Qreduce,其中 Q Q Q的维度尺寸为 (B,H,L,E) M t o p M_{top} Mtop(B,H,X),得到的 Q r e d u c e Q_{reduce} Qreduce(B,H,X,E)

乍一看,其实不是很好去实现这种功能,不能通过普通的Tensor索引去获取到对应元素,最朴素的想法是遍历B*H遍,然后分别获取对应的Q[b, h, x],然后再拼接起来。但这样的时间复杂度会达到O(b * h)的级别,怎么使用矩阵机制呢,官方代码给出的实现如下:

Q_reduce = Q[torch.arange(B)[:, None, None],torch.arange(H)[None, :, None],M_top, :]

解释:其中torch.arange(B)[:, None, None]先生成一个[0,B)之间的一维数组,再扩充成3维,最后的维度为 (B, 1, 1)【后面称Mb】,同理torch.arange(H)[None, :, None]最后维度为 (1, H, 1)【后面称Mh】,而 M t o p M_{top} Mtop(B,H,X)

首先会利用上面描述的广播机制,将Mb和Mh的维度扩充成(B,H,X),值得注意的是Mb是从(B,1,1)扩充的,这意味着只要确定的dim 0,Mb中的所有值都是一样的,比如Mb[0]里面就是一个(H,X)维的全0矩阵;同理Mh只要确定了dim 1,那么剩余的都是一样的值。

利用广播机制后,维度全部扩充为 (B, H, X),再进行花式索引,分别获取得到对应的值。

下面写个示例验证一下:

import torch
Q = torch.arange(16).reshape(2, 2, 4)
M_top = torch.randint(4, (2,2,2))
(tensor([[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7]],
 
         [[ 8,  9, 10, 11],
          [12, 13, 14, 15]]]),
 tensor([[[3, 0],
          [2, 0]],
 
         [[2, 0],
          [1, 3]]]))

我们要想从Q【2,2,4】中根据M_top【2,2,2】来获取对应的值,得到一个Q_reduce【2,2,2】,实现如下:

Q[torch.arange(2)[:, None, None], torch.arange(2)[None,:,None], M_top]
tensor([[[ 3,  0],
         [ 6,  4]],

        [[10,  8],
         [13, 15]]])

为了方便理解,这里显示Mb,Mh,M_top经过广播后的结果:
Pytorch:浅探Tensor的各种索引形式_第1张图片


以上均为个人理解和实验推理,如果不对和待补充的地方,还请指正。

你可能感兴趣的:(快乐ML/DL,pytorch,python,深度学习)