【Python 代码优化记录】遍历 + concatenate

0、背景

【目的】:现有一个数据条数很大(约 25w)的 numpy array:allData。想根据字典 aDict 里的值选取出其中的某些数据(约 2w)条组成一个新的 numpy array:data.

原代码大概的逻辑如下:(只放出了和我猜想和结论相关的部分)

data = np.array([])

for key, value in aDict.items():
    data = np.concatenate([data, allData[key]])

1、问题

代码到后面运行速度变得越来越慢

2、猜测 1:是不是由于字典数据量太大造成的?

【排除方法】:将 for 循环里面的代码注释掉,看运行时间。

【结论】:注释掉之后运行速度很快,所以问题应该出在被注释掉的 for 循环里面的语句。

3、正确推理 & 分析原因

定位到为题所在之后,搜了一下 np.concatenate() 的原理:python - Concatenate Numpy arrays without copying - Stack Overflow

【Python 代码优化记录】遍历 + concatenate_第1张图片

 numpy array 的内存必须是连续的,所以进行 np.concatenate() 时,相当于重新分配了一个大数组,再把要 concat 起来的小数组里的值全部 copy 进去。

所以耗时的原因在这里,在 for 循环里 concatenate,相当于每次都要重新分配并复制,越到后面,需要复制的值越多,所以就越慢。

4、解决方案

在进入 for 循环之前,先分配好最终 numpy array(data需要的所有空间

在 for 循环里,直接往 numpy array(data)里赋值即可。

代码如下:

data = np.empty(shape=(20000, ..., ..., ..., ...))
count = 0

for key, value in aDict.items():
    data[count, :, :, :, :] = allData[idx, :, :, :, :]
    count += 1

你可能感兴趣的:(Python,学习,python,numpy,concatenate,拼接,循环)