np.sum()函数中axis参数的理解

np.sum()函数中axis参数的理解:

import numpy as np
a = np.array([[[1, 2, 3, 2],[1, 2, 3, 1], [2, 3, 4, 1]],
              [[1, 0, 2, 0], [2, 1, 2, 0], [2, 1, 1, 1]]])
print(a.sum(axis=0))
'''
[[2 2 5 2]
 [3 3 5 1]
 [4 4 5 2]]
相当于把第0维压缩(对应值相加)成1:2*3*4 -> (1*)3*4
'''
print(a.sum(axis=1))
'''
[[4  7  10  4]
 [5  2   5  1]]
相当于把第1维压缩(对应值相加)成1:2*3*4 -> 2*(1*)4
'''
print(a.sum(axis=2))
'''
[[ 8  7 10]
 [ 3  5  5]]
相当于把第2维压缩(对应值相加)成1:2*3*4 -> 2*3(*1)
 '''

你可能感兴趣的:(np.sum()函数中axis参数的理解)