关于numpy mean函数的axis参数:2018-06-08

import numpy as np 

X = np.array([[1,2], [4,5], [7,8]]) 

print(np.mean(X, axis=0, keepdims=True)) 

print(np.mean(X, axis=1, keepdims=True))

axis=0,那么输出矩阵是1行,求每一列的平均(按照每一行去求平均);axis=1,输出矩阵是1列,求每一行的平均(按照每一列去求平均)。还可以这么理解,axis是几,那就表明哪一维度被压缩成1。

再举个更复杂点的例子,比如我们输入为batch = [128, 28, 28],可以理解为batch=128,图片大小为28×28像素,我们相求这128个图片的均值,应该这么写

m = np.mean(batch, axis=0)

输出结果m的shape为(28,28),就是这128个图片在每一个像素点平均值。

你可能感兴趣的:(关于numpy mean函数的axis参数:2018-06-08)