np.stack(array,axis,out=None)
,函数原型。
其中最重要是的这个axis怎么理解的。
举例说明:
arrays = [np.random.randn(3, 4) for _ in range(10)]
会生成一个 10 *( 3 * 4 )的矩阵列表。十个矩阵,每个矩阵是(3 * 4)大小。
首先说明一下axis的映射。在这个例子中,10->axis=0 ,3->axis=1
>>>np.stack(arrays,axis=1)
array([[[-0.42233185, -0.13270788, -0.47724388, -1.48881134],
[ 0.2284937 , -0.30139984, 0.15633374, 0.04428078],
[ 2.0193316 , 0.1098357 , -0.32044757, -1.24868601],
[ 0.9859909 , -0.42781564, 0.57524126, 0.58154297],
[-0.13059124, 2.15207301, 0.36007904, -0.71344781],
[-1.68010975, 1.25350273, 0.11073033, -0.28531604],
[ 0.60021096, -0.18691447, 1.49261775, 0.47628294],
[-0.18268831, -0.32463742, -0.89726008, 0.19245843],
[-0.27384598, 0.56068318, 1.57096001, 1.11169077],
[ 0.27035354, -0.54258351, -0.69891459, 1.84282464]],
[[ 1.44874184, -1.6645958 , 1.14128754, -2.26945958],
[ 0.28754711, -1.59591539, -0.92798468, -0.05021877],
[ 1.09050239, -0.86881164, -0.59820951, -0.39628311],
[-1.09540304, -0.33438594, -0.71075442, -1.48691938],
[ 0.7155825 , 0.24710929, -0.65019501, -1.24407802],
[-0.11059045, -1.57851632, 1.34142995, -0.44438407],
[ 0.9258746 , 1.62418684, -0.25380587, -1.1423341 ],
[-1.76337136, 0.55031978, 1.25834475, 0.53257722],
[ 0.05755626, 1.16156935, -1.84999546, 1.57175386],
[ 0.48836813, -0.21907532, -0.78655392, 0.51705705]],
[[-0.24451876, -0.09881284, 1.17611246, 0.81276037],
[ 0.89510841, 0.9106155 , 0.4923826 , -0.07364133],
[-0.0670429 , 0.72968107, -1.31473173, -0.31313322],
[ 0.62314248, 0.97792175, 0.0840199 , -0.38035465],
[ 0.70222737, 0.53761069, 0.50546661, -2.02777762],
[-0.85454667, -0.76359383, -0.25280887, -0.94252057],
[ 0.38294622, -0.38729216, 0.03757319, -0.48955485],
[ 1.52718003, 1.14814816, 1.33147053, -0.50341043],
[-0.38600834, 0.19781327, -0.35596671, 1.59331045],
[-0.07073478, -1.4710414 , 1.95192939, -0.83379204]]])
>>> np.stack(arrays, axis=1).shape
(3, 10, 4)
为什么会变成 3 * 10 * 4了呢。首先我们的函数是对 10 * 3 * 4 中的3,也就是axis=1,进行了堆叠。
那么这个 axis = 1,在十个矩阵中代表什么呢?代表 每个矩阵中的一行。所以这个函数的操作就是,把10矩阵中的第i行拿出来拼成一个矩阵。因为一个矩阵有三行,所以堆叠后的矩阵就是,3 * 10 * 4,这个10 * 4,就是原来矩阵中,十个矩阵的第一行,第二行,第三行,拼接而成的。所以是 3 * 10 * 4。