Original array:
In [1]: import numpy as np
In [2]: a = np.array([[[1,2,3],[4,5,6]]])
In [3]: a.shape
Out[3]: (1, 2, 3)
np.expand_dims(a, axis=0) inserts a new axis of length 1 at position 0 of the shape. The result:
In [4]: b = np.expand_dims(a, axis=0)
In [5]: b
Out[5]:
array([[[[1, 2, 3],
         [4, 5, 6]]]])
In [6]: b.shape
Out[6]: (1, 1, 2, 3)
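np.expand_dims(a, axis=1) inserts a new axis of length 1 at position 1 of the shape. Because the leading axis of a already has length 1, the result coincides with the axis=0 case:
In [7]: c = np.expand_dims(a, axis=1)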
In [8]: c
Out[8]:
array([[[[1, 2, 3],
         [4, 5, 6]]]])
In [9]: c.shape
Out[9]: (1, 1, 2, 3)
np.expand_dims(a, axis=2) inserts a new axis of length 1 at position 2 of the shape. The result:
In [10]: d = np.expand_dims(a, axis=2)
In [11]: d
Out[11]:
array([[[[1, 2, 3]],

        [[4, 5, 6]]]])
In [12]: d.shape
Out[12]: (1, 2, 1, 3)
np.expand_dims(a, axis=3) inserts a new axis of length 1 at position 3 of the shape, that is, after the last existing axis. The result:
In [13]: e = np.expand_dims(a, axis=3)
In [14]: e
Out[14]:
array([[[[1],
         [2],
         [3]],

        [[4],
         [5],
         [6]]]])
In [15]: e.shape
Out[15]: (1, 2, 3, 1)
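The four cases above can be condensed into a short sketch that loops over every non-negative axis and records the resulting shape. The names `shapes`, `last`, and `via_newaxis` are illustrative only and do not appear in the original session. The sketch also shows two equivalent spellings that are standard NumPy behavior: a negative axis, and indexing with np.newaxis.

```python
import numpy as np

a = np.array([[[1, 2, 3], [4, 5, 6]]])  # shape (1, 2, 3)

# Insert a length-1 axis at each non-negative position and record the new shape.
shapes = {axis: np.expand_dims(a, axis=axis).shape for axis in range(a.ndim + 1)}
for axis, shape in shapes.items():
    print(axis, shape)

# Negative axes count from the end of the result, so -1 matches axis=3 here.
last = np.expand_dims(a, axis=-1)

# Indexing with np.newaxis is an equivalent spelling of axis=2.
via_newaxis = a[:, :, np.newaxis, :]
```

Only the shape changes in each case; the six values and their order in memory stay the same.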
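np.expand_dims(a, axis=4), or any axis >= 4, is out of bounds: the result has 4 dimensions, so the valid axes are -4 through 3. NumPy raises an AxisError: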
In [16]: f = np.expand_dims(a, axis=4)
---------------------------------------------------------------------------
AxisError Traceback (most recent call last)
<ipython-input-16-726148459f80> in <module>
----> 1 f = np.expand_dims(a, axis=4)
<__array_function__ internals> in expand_dims(*args, **kwargs)
~/.local/lib/python3.6/site-packages/numpy/lib/shape_base.py in expand_dims(a, axis)
595
596 out_ndim = len(axis) + a.ndim
--> 597 axis = normalize_axis_tuple(axis, out_ndim)
598
599 shape_it = iter(a.shape)
~/.local/lib/python3.6/site-packages/numpy/core/numeric.py in normalize_axis_tuple(axis, ndim, argname, allow_duplicate)
1325 pass
1326 # Going via an iterator directly is slower than via list comprehension.
-> 1327 axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
1328 if not allow_duplicate and len(set(axis)) != len(axis):
1329 if argname:
~/.local/lib/python3.6/site-packages/numpy/core/numeric.py in <listcomp>(.0)
1325 pass
1326 # Going via an iterator directly is slower than via list comprehension.
-> 1327 axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
1328 if not allow_duplicate and len(set(axis)) != len(axis):
1329 if argname:
AxisError: axis 4 is out of bounds for array of dimension 4
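As a rough illustration of the valid range, the sketch below probes a span of axis values and keeps those that expand_dims accepts. `safe_expand_dims` is a hypothetical helper written for this example, not part of NumPy. It assumes a NumPy version that raises AxisError for an out-of-range axis, as in the traceback above; AxisError is a subclass of ValueError, so catching ValueError is enough.

```python
import numpy as np

a = np.array([[[1, 2, 3], [4, 5, 6]]])  # ndim == 3


def safe_expand_dims(arr, axis):
    """Hypothetical helper: return the expanded array, or None if axis is invalid."""
    try:
        return np.expand_dims(arr, axis=axis)
    except ValueError:  # AxisError subclasses ValueError (and IndexError)
        return None


# Probe a range of candidate axes and keep the ones NumPy accepts.
valid = [ax for ax in range(-6, 7) if safe_expand_dims(a, ax) is not None]
print(valid)
```

Under that assumption, the accepted values should be exactly -(a.ndim + 1) through a.ndim, since the axis is normalized against the number of dimensions of the output rather than the input.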
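np.argmax without an axis argument returns the index of the maximum value in the flattened array (here a is 2-D with shape (1, 6)):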
In [4]: a = np.arange(6).reshape(1,6)
In [5]: a
Out[5]: array([[0, 1, 2, 3, 4, 5]])
In [6]: b = np.argmax(a)
In [7]: b
Out[7]: 5
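Because the index returned without an axis refers to the flattened array, it can be converted back into per-axis coordinates with np.unravel_index. The following is a minimal sketch; the variable names `flat_index` and `coords` are chosen here for illustration.

```python
import numpy as np

a = np.arange(6).reshape(1, 6)  # 2-D array, shape (1, 6)

# Index of the maximum in the flattened array.
flat_index = np.argmax(a)

# Convert the flat index back into (row, column) coordinates.
coords = np.unravel_index(flat_index, a.shape)

print(flat_index)
print(coords)
print(a[coords])  # the maximum value itself
```

For this array the flat index and the column index happen to be equal, since there is only one row.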
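axis=0 compares along the first axis (down each column) and returns one index per column;
axis=1 compares along the second axis (across each row) and returns one index per row.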
In [13]: c = np.array([[1,4,3],[2,1,4]])
In [14]: c
Out[14]:
array([[1, 4, 3],
       [2, 1, 4]])
In [15]: d = np.argmax(c, axis=0)
In [16]: d
Out[16]: array([1, 0, 1])
In [17]: e = np.argmax(c, axis=1)
In [18]: e
Out[18]: array([1, 2])
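The same applies to three- and four-dimensional arrays: the axis passed to argmax is reduced away, and the result holds one index for each position along the remaining axes.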
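A small sketch of the three-dimensional case, using a made-up array `x` with one planted maximum so the indices are easy to check by hand. The array and the planted value are assumptions made for this example only.

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
x[0, 1, 2] = 100  # plant a known maximum at position (0, 1, 2)

# The reduced axis disappears from the result's shape.
shapes = [np.argmax(x, axis=ax).shape for ax in range(x.ndim)]
for ax, shape in enumerate(shapes):
    print(ax, shape)

# Each entry is the index along the reduced axis where the maximum occurs.
along_0 = np.argmax(x, axis=0)  # compares x[0, i, j] with x[1, i, j]
along_1 = np.argmax(x, axis=1)  # compares x[i, 0, j], x[i, 1, j], x[i, 2, j]
along_2 = np.argmax(x, axis=2)  # compares x[i, j, 0] through x[i, j, 3]
```

Reading the planted value back out: `along_0[1, 2]`, `along_1[0, 2]`, and `along_2[0, 1]` each point at the coordinate of 100 along the axis that was reduced.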