python 如何理解 numpy 数组操作中的 axis 参数?

以前在看numpy数组操作的有关axis的操作时, 常常理解不了, 比如像下面这种:

[代码1]
求沿指定轴上的最大值(2维):

import numpy as np
a = np.array([[78, 34, 87, 25, 83], [25, 67, 97, 22, 13], [78, 43, 87, 45, 89]])
print(a.max(axis=0))
print(a.max(axis=1))

打印输出1:

[78 67 97 45 89]
[87 97 89]

[代码2]
求沿指定轴上的最大值(三维):

import numpy as np
a = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
print(a.max(axis=0))
print('\n')
print(a.max(axis=1))
print('\n')
print(a.max(axis=2))

打印输出2:

[[4 5]
 [6 7]]
 
[[2 3]
 [6 7]]
 
[[1 3]
 [5 7]]

想着糊涂, 看着也糊涂, 所以我干脆画个图:

  • 二维的情况:
    python 如何理解 numpy 数组操作中的 axis 参数?_第1张图片

  • 三维的情况:
    python 如何理解 numpy 数组操作中的 axis 参数?_第2张图片

  • 用三维的情况解释一下, axis等于几, 就用那个维度的数字作比较, 比如axis=0, 表示用最外层的那个维度的数字作比较 ( 最外层是啥, 最外层就是 numpy 数组最外面的括号包裹的层 ), 放在图形中就是在 z 轴方向上比较, 以此类推 axis=1 为次外层作比较, axis=2 为…

参考文章: 理解numpy的rollaxis与swapaxes函数
https://blog.csdn.net/liaoyuecai/article/details/80193996

你可能感兴趣的:(深入浅出,python机器学习)