深入解析Tensor索引中的Indexing Multi-dimensional arrays问题

写在前面

最近小弟做了一些实验,但是发现我写的代码虽然能够跑通,但是对于gpu的利用率始终在一个比较低的水平,这就很难受,别人的代码2h就跑完了,我得10h,经过排查发现究其原因就是代码的并行化成都不高,在代码中使用了大量的for循环,没有采用矩阵运算,就导致计算非常的慢,于是最近在学习一些大神的代码,遇到了这个在Tensor中的Indexing Multi-dimensional arrays问题。

前置知识

在解决这个问题之前,需要了解torch中的boardcast机制,详情可见pytorch官网。
简单来说就是两个tensor在判断是否可以广播时,要进行以下两个步骤:

x=torch.ones(5,1,4,2)
y=torch.ones(3,1,1)
(x+y).size() # Tensor(5, 3, 4, 2)
  1. 将两个tensor按照尾部的维度对齐,即x的最后一个维度2与y的最后一个维度1对齐,x的倒数第二个维度1与y的倒数第二个维度1对齐,直到其中的一个tensor没有维度就停止。在这里就是y的第一个维度3停止,对应于x的第二个维度1.
  2. 在每次对齐中,如果两个维度不一样,且其中一个维度为1,那么就把1变成另维度,比如最后一个维度是1和2,那么就把1变成2(代表着y在最后一个维度(列)上copy了一次);比如3和1,那么就把1变成3(代表着x在第二个维度上copy了两次)
  • PS:如果一对对应的维度数字中,两个数字不同,并且没有一个维度为1,那么就会报错

问题描述

话不多说,直接上问题。

x=torch.ones(5,8,512) # Tensor(5, 8, 512)
y=torch.ones(5,6,2) # Tensor(5, 6, 2)
result = x[torch.arange(x.size(0)), y.permute(2, 1, 0)] # Tensor(2, 6, 5, 512)

通过上述代码可以看到,x的维度为Tensor(5, 8, 512),y的维度为Tensor(5, 6, 2),但是result出来的维度为Tensor(2, 6, 5, 512),这我直接顶不住,不知道为神马。且听下面分析。

理性分析

a = torch.arange(x.size(0)) # Tensor(5, )
aa = y.permute(2, 1, 0) # Tensor(2, 6, 5)
aaa = [torch.arange(x.size(0)), y.permute(2, 1, 0)] # [Tensor(5, ), Tensor(2, 6, 5)]

可以看到,aaa就对应了result中我妈要取的Multi-dimensional Index,它是一个list,中间有两个元素,每个元素都是一个Tensor,且维度不一样。这里就有一个补充知识,在做这种Multi-dimensional Index操作的时候,list中的元素要么需要保证维度相同,要么需要保证可以广播,因为在维度不相同时便会进行广播操作,详见NumPy文档。

所以此时就要对a和aa进行广播,得到Tensor(2, 6, 5)这个维度,于是索引就变为了[Tensor(2, 6, 5), Tensor(2, 6, 5)]。

那么在索引是list中嵌套list时,是如何根据下标索引元素的,看下面这个例子。

 y = np.arange(35).reshape(5,7)
 y[np.array([0,2,4]), np.array([0,1,2])]
 >>> array([ 0, 15, 30])

可以看到,索引的过程是第一个list中的[0]元素与第二个list中的[0]元素对应着取的,并不是n*n,而是n个一一对应的结果,如果是Tensor的话则会自动做一个concat操作。

那么刚刚的问题其实就变成了:

x[[Tensor(2,6,5)], Tensor[2,6,5]] # 其实就相当于选取两个下标,对应于i和j,选2*6*5次

并且由于list中只有2个元素,第0号元素对应x中第0个维度(即5),第1号元素对应x中第1个维度(即8),得到的维度即为Tensor(2, 6, 5, 512)。

如果换成下面这种写法:

result = x[torch.arange(x.size(0)), y.permute(2, 1, 0), y.permute(2, 1, 0)]

则list中有三个元素,广播过后每个元素的维度为Tensor(2,6,5),且分别对应x的三个维度,出来的维度即为Tensor(2,6,5)了。

写在最后

Tensor这种操作,当维度高了过后确实无法用空间去想象了,比较抽象,希望后面能慢慢熟悉,提高代码可读性。

你可能感兴趣的:(pytorch,python,深度学习)