另一种选择是多维列表位置索引:
import numpy as np
ncol = 10 # 10 in your case
nrow = 500 # 500 in your case
# just creating some test data:
x = np.arange(ncol*nrow).reshape(nrow,ncol)
y = (ncol * np.random.random_sample((nrow, 1))).astype(int)
print(x)
print(y)
print(x[np.arange(nrow),y.T].T)
语法解释为here.您基本上需要每个维度的索引数组.在第一个维度中,在您的情况下,这只是[0,…,500],第二个维度是您的y数组.我们需要转置它(.T),因为它必须具有与第一个和输出数组相同的形状.第二个换位不是真的需要,但给你你想要的形状.
编辑:
性能问题出现了,我尝试了迄今为止提到的三种方法.你需要line_profiler运行以下内容
kernprof -l -v tmp.py
其中tmp.py是:
import numpy as np
@profile
def calc(x,y):
z = np.arange(nrow)
a = x[z,y.T].T # mine, with the suggested speed up
b = x[:,y].diagonal().T # Christoph Terasa
c = np.array([i[j] for i, j in zip(x, y)]) # tobias_k
return (a,b,c)
ncol = 5 # 10 in your case
nrow = 10 # 500 in your case
x = np.arange(ncol*nrow).reshape(nrow,ncol)
y = (ncol * np.random.random_sample((nrow, 1))).astype(int)
a, b, c = calc(x,y)
print(a==b)
print(b==c)
我的python 2.7.6的输出:
Line # Hits Time Per Hit % Time Line Contents
==============================================================
3 @profile
4 def calc(x,y):
5 1 4 4.0 0.1 z = np.arange(nrow)
6 1 35 35.0 0.8 a = x[z,y.T].T
7 1 3409 3409.0 76.7 b = x[:,y].diagonal().T
8 501 995 2.0 22.4 c = np.array([i[j] for i, j in zip(x, y)])
9
10 1 1 1.0 0.0 return (a,b,c)
其中%Time或Time是相关列.我不知道如何描述内存消耗,其他人则必须这样做.现在看起来我的解决方案对于所请求的尺寸来说是最快的.