Numpy:np.expand_dims()&np.argmax() 用法

np.expand_dims:用于扩展数组的形状

原始数组

In [1]: import numpy as np                                                                                                                                                               

In [2]: a = np.array([[[1,2,3],[4,5,6]]])                                                                                                                                                

In [3]: a.shape                                                                                                                                                                          
Out[3]: (1, 2, 3)

np.expand_dims(a, axis=0)表示在0位置添加数据,转换结果如下:

In [4]: b = np.expand_dims(a, axis=0)                                                                                                                                                    

In [5]: b                                                                                                                                                                                
Out[5]: 
array([[[[1, 2, 3],
         [4, 5, 6]]]])

In [6]: b.shape                                                                                                                                                                          
Out[6]: (1, 1, 2, 3)

np.expand_dims(a, axis=1)表示在1位置添加数据,转换结果如下:

In [7]: c = np.expand_dims(a ,axis=1)                                                                                                                                                    

In [8]: c                                                                                                                                                                                
Out[8]: 
array([[[[1, 2, 3],
         [4, 5, 6]]]])

In [9]: c.shape                                                                                                                                                                          
Out[9]: (1, 1, 2, 3)

np.expand_dims(a, axis=2)表示在2位置添加数据,转换结果如下:

In [10]: d = np.expand_dims(a, axis=2)                                                                                                                                                   

In [11]: d                                                                                                                                                                               
Out[11]: 
array([[[[1, 2, 3]],

        [[4, 5, 6]]]])

In [12]: d.shape                                                                                                                                                                         
Out[12]: (1, 2, 1, 3)

np.expand_dims(a, axis=3)表示在3位置添加数据,转换结果如下:

In [13]: e = np.expand_dims(a, axis=3)                                                                                                                                                   

In [14]: e                                                                                                                                                                               
Out[14]: 
array([[[[1],
         [2],
         [3]],

        [[4],
         [5],
         [6]]]])

In [15]: e.shape                                                                                                                                                                         
Out[15]: (1, 2, 3, 1)

np.expand_dims(a, axis=1)表示在>=4位置添加数据,转换结果如下:

In [16]: f = np.expand_dims(a, axis=4)                                                                                                                                                   
---------------------------------------------------------------------------
AxisError                                 Traceback (most recent call last)
<ipython-input-16-726148459f80> in <module>
----> 1 f = np.expand_dims(a, axis=4)

<__array_function__ internals> in expand_dims(*args, **kwargs)

~/.local/lib/python3.6/site-packages/numpy/lib/shape_base.py in expand_dims(a, axis)
    595 
    596     out_ndim = len(axis) + a.ndim
--> 597     axis = normalize_axis_tuple(axis, out_ndim)
    598 
    599     shape_it = iter(a.shape)

~/.local/lib/python3.6/site-packages/numpy/core/numeric.py in normalize_axis_tuple(axis, ndim, argname, allow_duplicate)
   1325             pass
   1326     # Going via an iterator directly is slower than via list comprehension.
-> 1327     axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
   1328     if not allow_duplicate and len(set(axis)) != len(axis):
   1329         if argname:

~/.local/lib/python3.6/site-packages/numpy/core/numeric.py in <listcomp>(.0)
   1325             pass
   1326     # Going via an iterator directly is slower than via list comprehension.
-> 1327     axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
   1328     if not allow_duplicate and len(set(axis)) != len(axis):
   1329         if argname:

AxisError: axis 4 is out of bounds for array of dimension 4

np.argmax:返回沿轴最大值的索引值

一维数组

返回一维数组中最大值的索引值:

In [4]: a = np.arange(6).reshape(1,6)                                                                                                                                                    

In [5]: a                                                                                                                                                                                
Out[5]: array([[0, 1, 2, 3, 4, 5]])
In [6]: b = np.argmax(a)                                                                                                                                                                 

In [7]: b                                                                                                                                                                                
Out[7]: 5

二维数组

axis=0表示沿行比较,输出最大值索引;
axis=1表示沿列比较,输出最大值索引。

In [13]: c = np.array([[1,4,3],[2,1,4]])                                                                                                                                                 

In [14]: c                                                                                                                                                                               
Out[14]: 
array([[1, 4, 3],
       [2, 1, 4]])
       In [16]: d = np.argmax(c, axis=0) 

In [15]: d = np.argmax(c, axis=0)                                                                                                                                                        

In [16]: d                                                                                                                                                                               
Out[16]: array([1, 0, 1])

In [17]: e = np.argmax(c, axis=1)                                                                                                                                                        

In [18]: e                                                                                                                                                                               
Out[18]: array([1, 2])                                                                                                                     

三维、四维同理。

你可能感兴趣的:(Numpy)