关于 np.stack(arrays, axis=0, out=None) 的理解

这个函数看的我头疼,翻来翻去看了好多人的解释,还是一头雾水
所以就自己写了点代码,记录一下

import numpy as np
a = np.array(range(1, 25)).reshape(2, 3, 4)
# 构造一个三维的数组,那么axis就可以取2了,axis最大只能取(维数-1)
array([[[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]],
       [[13, 14, 15, 16],
        [17, 18, 19, 20],
        [21, 22, 23, 24]]])

# axis=0,发现结果和原来一样
np.stack(a, axis=0)
shape:(2, 3, 4)
array([[[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]],
       [[13, 14, 15, 16],
        [17, 18, 19, 20],
        [21, 22, 23, 24]]])

# axis=1,结果是第1维里的2个数组,这2个数组里面相同下标的元素搞一块去了,当然还是三维
# [1, 2, 3, 4] <=> [13, 14, 15, 16]
np.stack(a, axis=1)
shape:(3, 2, 4)
array([[[ 1,  2,  3,  4],
        [13, 14, 15, 16]],
       [[ 5,  6,  7,  8],
        [17, 18, 19, 20]],
       [[ 9, 10, 11, 12],
        [21, 22, 23, 24]]])
        
# axis=2,结果是第2维里的3个数组,这3个数组里面相同下标的元素搞一块去了,当然也还是三维
# [1] <=> [13]
np.stack(a, axis=2)
shape:(3, 4, 2)
array([[[ 1, 13],
        [ 2, 14],
        [ 3, 15],
        [ 4, 16]],
       [[ 5, 17],
        [ 6, 18],
        [ 7, 19],
        [ 8, 20]],
       [[ 9, 21],
        [10, 22],
        [11, 23],
        [12, 24]]])

你可能感兴趣的:(numpy)