Python NumPy iteration: iterating over an arbitrary dimension of a numpy.array

Is there a function to get an iterator over an arbitrary dimension of a numpy array?

Iterating over the first dimension is easy...

```
In [63]: c = numpy.arange(24).reshape(2,3,4)

In [64]: for r in c:
   ....:     print(r)
   ....:
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]]
[[12 13 14 15]
 [16 17 18 19]
 [20 21 22 23]]
```

But iterating over other dimensions is harder. For example, the last dimension:

```
In [73]: for r in c.swapaxes(2,0).swapaxes(1,2):
   ....:     print(r)
   ....:
[[ 0  4  8]
 [12 16 20]]
[[ 1  5  9]
 [13 17 21]]
[[ 2  6 10]
 [14 18 22]]
[[ 3  7 11]
 [15 19 23]]
```
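The chained swapaxes() calls simply move the last axis to the front. As a quick check, here is a short sketch (assuming NumPy 1.11 or later, where numpy.moveaxis exists) showing that transpose() and moveaxis() produce the same view in a single step:

```python
import numpy as np

c = np.arange(24).reshape(2, 3, 4)

# The chain from the question: shape (2, 3, 4) -> (4, 3, 2) -> (4, 2, 3)
chained = c.swapaxes(2, 0).swapaxes(1, 2)

# Two equivalent one-step spellings that move axis 2 to the front
transposed = c.transpose(2, 0, 1)
moved = np.moveaxis(c, 2, 0)

print(chained.shape)  # (4, 2, 3)
print(np.array_equal(chained, transposed), np.array_equal(chained, moved))  # True True

# Each item produced by iterating the view is the slice c[:, :, k]
for k, r in enumerate(chained):
    assert np.array_equal(r, c[:, :, k])
```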

I'm making a generator to do this myself, but I'm surprised there isn't a function named something like numpy.ndarray.iterdim(axis=0) to do this automatically.
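Below is a minimal sketch of such a generator. The name iterdim is the hypothetical one from the question, not an existing NumPy function or method. This version assumes numpy.moveaxis (NumPy 1.11+) and yields views rather than copies:

```python
import numpy as np

def iterdim(a, axis=0):
    """Yield the sub-arrays of `a` taken along `axis` (hypothetical helper).

    Each item equals `a` indexed with an integer at position `axis`
    and full slices everywhere else.
    """
    a = np.asarray(a)
    # moveaxis returns a view with `axis` moved to the front; iterating
    # over that view walks its first dimension without copying data.
    yield from np.moveaxis(a, axis, 0)

c = np.arange(24).reshape(2, 3, 4)
for r in iterdim(c, axis=-1):
    print(r)
```

With axis=-1 this prints the same four (2, 3) arrays as the swapaxes() loop above.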

Solution

What you propose is quite fast, but legibility can be improved with clearer forms such as:

```
for i in range(c.shape[-1]):
    print(c[:,:,i])
```

or, better (faster, more general, and more explicit, since the ellipsis stands for however many leading axes the array has):

```
for i in range(c.shape[-1]):
    print(c[...,i])
```
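The ellipsis form only reaches the last axis (c[..., i]) or the first (c[i, ...]). For an axis chosen at run time, one option is to build the index tuple explicitly. The sketch below uses a hypothetical helper, take_index. Basic indexing with slices and an integer still returns a view; numpy.take(a, i, axis=axis) gives the same values but as a copy.

```python
import numpy as np

def take_index(a, axis, i):
    """Return `a` indexed with integer i at position `axis` (hypothetical helper)."""
    axis = axis % a.ndim  # allow negative axes such as -1
    # Full slices for the leading axes, then the integer; trailing axes
    # are implicitly taken whole.
    index = (slice(None),) * axis + (i,)
    return a[index]

c = np.arange(24).reshape(2, 3, 4)
for i in range(c.shape[1]):
    print(take_index(c, 1, i))  # same as c[:, i, :]
```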

However, the first form above (c[:,:,i]) appears to be about 1.6 times as slow as the swapaxes() approach (6.08 vs. 3.69 usec per loop):

```
python -m timeit -s 'import numpy; c = numpy.arange(24).reshape(2,3,4)' \
    'for r in c.swapaxes(2,0).swapaxes(1,2): u = r'
100000 loops, best of 3: 3.69 usec per loop

python -m timeit -s 'import numpy; c = numpy.arange(24).reshape(2,3,4)' \
    'for i in range(c.shape[-1]): u = c[:,:,i]'
100000 loops, best of 3: 6.08 usec per loop
```

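For comparison, numpy.rollaxis(c, 2), which moves axis 2 to the front, is slightly slower still:

```
python -m timeit -s 'import numpy; c = numpy.arange(24).reshape(2,3,4)' \
    'for r in numpy.rollaxis(c, 2): u = r'
100000 loops, best of 3: 6.46 usec per loop
```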

I would guess that this is because swapaxes() builds a single view up front, without copying any data, and the loop then just walks its first axis, whereas every c[:,:,i] (also a view, not a copy) goes through general indexing code that must also handle the case where : is replaced by a more complicated slice.
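One way to check the "no copy" part of this guess is numpy.shares_memory. In the sketch below, the swapaxes(), basic-slice, ellipsis and rollaxis() results are all views of c, while a fancy (list) index is included only as a contrast, because it always returns a copy. If that holds, the timing gap more likely comes from per-call indexing overhead than from copying.

```python
import numpy as np

c = np.arange(24).reshape(2, 3, 4)

candidates = {
    "swapaxes": c.swapaxes(2, 0).swapaxes(1, 2),
    "basic slice": c[:, :, 0],
    "ellipsis": c[..., 0],
    "rollaxis": np.rollaxis(c, 2),
    "fancy index": c[:, :, [0]],  # contrast: advanced indexing makes a copy
}

for name, arr in candidates.items():
    # True when arr is a view onto c's buffer, False for an independent copy
    print(name, np.shares_memory(arr, c))
```

This prints True for the first four entries and False for the fancy-indexed one.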

Note however that the more explicit second solution c[...,i] is both quite legible and quite fast:

```
python -m timeit -s 'import numpy; c = numpy.arange(24).reshape(2,3,4)' \
    'for i in range(c.shape[-1]): u = c[...,i]'
100000 loops, best of 3: 4.74 usec per loop
```
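The same comparison can also be run from a script with the standard timeit module. This is a sketch: each approach here collects the slices into a list (slightly different from the u = r loops above), it uses fewer calls per repeat to keep the run short, and absolute numbers depend on the machine and on the Python and NumPy versions. It also checks that every approach yields the same slices.

```python
import timeit
import numpy as np

c = np.arange(24).reshape(2, 3, 4)
NUMBER = 10_000  # calls per repeat

# Each callable iterates over the last axis of c and collects the slices
approaches = {
    "swapaxes": lambda: list(c.swapaxes(2, 0).swapaxes(1, 2)),
    "c[:,:,i]": lambda: [c[:, :, i] for i in range(c.shape[-1])],
    "rollaxis": lambda: list(np.rollaxis(c, 2)),
    "c[...,i]": lambda: [c[..., i] for i in range(c.shape[-1])],
}

# Sanity check: all approaches produce the same sequence of (2, 3) slices
reference = approaches["c[...,i]"]()
for fn in approaches.values():
    result = fn()
    assert len(result) == len(reference)
    assert all(np.array_equal(a, b) for a, b in zip(result, reference))

for name, fn in approaches.items():
    # Best of 3 repeats, reported per call in microseconds
    best = min(timeit.repeat(fn, number=NUMBER, repeat=3))
    print(f"{name:10s} {best / NUMBER * 1e6:.2f} usec per loop")
```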
