最近小弟做了一些实验,但是发现我写的代码虽然能够跑通,但是对于gpu的利用率始终在一个比较低的水平,这就很难受,别人的代码2h就跑完了,我得10h,经过排查发现究其原因就是代码的并行化成都不高,在代码中使用了大量的for循环,没有采用矩阵运算,就导致计算非常的慢,于是最近在学习一些大神的代码,遇到了这个在Tensor中的Indexing Multi-dimensional arrays问题。
在解决这个问题之前,需要了解torch中的boardcast机制,详情可见pytorch官网。
简单来说就是两个tensor在判断是否可以广播时,要进行以下两个步骤:
x=torch.ones(5,1,4,2)
y=torch.ones(3,1,1)
(x+y).size() # Tensor(5, 3, 4, 2)
话不多说,直接上问题。
x=torch.ones(5,8,512) # Tensor(5, 8, 512)
y=torch.ones(5,6,2) # Tensor(5, 6, 2)
result = x[torch.arange(x.size(0)), y.permute(2, 1, 0)] # Tensor(2, 6, 5, 512)
通过上述代码可以看到,x的维度为Tensor(5, 8, 512),y的维度为Tensor(5, 6, 2),但是result出来的维度为Tensor(2, 6, 5, 512),这我直接顶不住,不知道为神马。且听下面分析。
a = torch.arange(x.size(0)) # Tensor(5, )
aa = y.permute(2, 1, 0) # Tensor(2, 6, 5)
aaa = [torch.arange(x.size(0)), y.permute(2, 1, 0)] # [Tensor(5, ), Tensor(2, 6, 5)]
可以看到,aaa就对应了result中我妈要取的Multi-dimensional Index,它是一个list,中间有两个元素,每个元素都是一个Tensor,且维度不一样。这里就有一个补充知识,在做这种Multi-dimensional Index操作的时候,list中的元素要么需要保证维度相同,要么需要保证可以广播,因为在维度不相同时便会进行广播操作,详见NumPy文档。
所以此时就要对a和aa进行广播,得到Tensor(2, 6, 5)这个维度,于是索引就变为了[Tensor(2, 6, 5), Tensor(2, 6, 5)]。
那么在索引是list中嵌套list时,是如何根据下标索引元素的,看下面这个例子。
y = np.arange(35).reshape(5,7)
y[np.array([0,2,4]), np.array([0,1,2])]
>>> array([ 0, 15, 30])
可以看到,索引的过程是第一个list中的[0]元素与第二个list中的[0]元素对应着取的,并不是n*n,而是n个一一对应的结果,如果是Tensor的话则会自动做一个concat操作。
那么刚刚的问题其实就变成了:
x[[Tensor(2,6,5)], Tensor[2,6,5]] # 其实就相当于选取两个下标,对应于i和j,选2*6*5次
并且由于list中只有2个元素,第0号元素对应x中第0个维度(即5),第1号元素对应x中第1个维度(即8),得到的维度即为Tensor(2, 6, 5, 512)。
如果换成下面这种写法:
result = x[torch.arange(x.size(0)), y.permute(2, 1, 0), y.permute(2, 1, 0)]
则list中有三个元素,广播过后每个元素的维度为Tensor(2,6,5),且分别对应x的三个维度,出来的维度即为Tensor(2,6,5)了。
Tensor这种操作,当维度高了过后确实无法用空间去想象了,比较抽象,希望后面能慢慢熟悉,提高代码可读性。