在numpy中axis是一个比较难理解的点,在很长一段时间我都是在处理一些2维的数组,所以往往对这块知识有所忽略,直到我在做斯坦福的cs231n的assignment时候,才对axis有了更加深入的理解。
先来看一组简单的代码。
import numpy as np
t = np.arange(8).reshape(2, 4)
print("origin: ", t, " shape: ", t.shape)
print("sum: ", t.sum(0), " shape: ", t.sum(0).shape)
print("sum: ", t.sum(1), " shape: ", t.sum(1).shape)
结果如下:
origin: [[0 1 2 3]
[4 5 6 7]] shape: (2, 4)
sum: [ 4 6 8 10] shape: (4,)
sum: [ 6 22] shape: (2,)
首先确定一点,axis=0是shape中从左往右数的第0个轴也就是(“2”,4)加引号部分,以此类推
我们要求和的axis就是将该轴消去 ,这里的t.sum(0)即消去0轴故只剩下 (4,),sum(1)同理
现在从结果来看,以sum(1)为例 写出下标 6: (0,x) 22 :(1,x) 这里写出x表示我们消去的轴
现在我们可以知道其实就是执行一个循环将所有0轴相同的元素加起来
6:(0,x) = 0(0,0)+1(0,1) +2(0,2)+3(0,3)
22:(1,x)=4(1,0)+5(1,1) +6(1,2)+7(1,3)
而 t.sum((0,1))=28(x,x)则是把两个维度的都加起来,简单不赘述
是不是还挺好理解的下面我们看下一组三维的数组
import numpy as np
t = np.arange(8).reshape(2, 2, 2)
print("origin: ", t, " shape: ", t.shape)
print("sum: ", t.sum(1), " shape: ", t.sum(0).shape)
print("sum: ", t.sum((1, 2)), " shape: ", t.sum(1).shape)
origin: [[[0 1]
[2 3]]
[[4 5]
[6 7]]] shape: (2, 2, 2)
sum: [[ 2 4]
[10 12]] shape: (2, 2)
sum: [ 6 22] shape: (2, 2)
这里我将下标一一标出
t :
0: (0,0,0)
1: (0,0,1)
2: (0,1,0)
3: (0,1,1)
4: (1,0,0)
5: (1,0,1)
6: (1,1,0)
7: (1,1,1)
t.sum(1):
2:(0,x,0)
4:(0,x,1)
10:(1,x,0)
12:(1,x,1)
相信看到这里已经很明了了,和2维的情况类似,x是被消去的轴,我们只要找到第0轴和第2轴相同元素相加即可
t.sum((1,2))
6:(0,x,x)
22:(1,x,x)
第0轴是其保留下来的轴,我们找到第0轴相同的元素全部相加,即可求得结果。
这里只写sum函数,其实mean,std,max等等函数都是一样的操作。
如果这篇文章对你有帮助,顺手点个赞,我会很开心!