这个函数看的我头疼,翻来翻去看了好多人的解释,还是一头雾水
所以就自己写了点代码,记录一下
import numpy as np
a = np.array(range(1, 25)).reshape(2, 3, 4)
# 构造一个三维的数组,那么axis就可以取2了,axis最大只能取(维数-1)
array([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]])
# axis=0,发现结果和原来一样
np.stack(a, axis=0)
shape:(2, 3, 4)
array([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]])
# axis=1,结果是第1维里的2个数组,这2个数组里面相同下标的元素搞一块去了,当然还是三维
# [1, 2, 3, 4] <=> [13, 14, 15, 16]
np.stack(a, axis=1)
shape:(3, 2, 4)
array([[[ 1, 2, 3, 4],
[13, 14, 15, 16]],
[[ 5, 6, 7, 8],
[17, 18, 19, 20]],
[[ 9, 10, 11, 12],
[21, 22, 23, 24]]])
# axis=2,结果是第2维里的3个数组,这3个数组里面相同下标的元素搞一块去了,当然也还是三维
# [1] <=> [13]
np.stack(a, axis=2)
shape:(3, 4, 2)
array([[[ 1, 13],
[ 2, 14],
[ 3, 15],
[ 4, 16]],
[[ 5, 17],
[ 6, 18],
[ 7, 19],
[ 8, 20]],
[[ 9, 21],
[10, 22],
[11, 23],
[12, 24]]])