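In short, expand_dims(a, axis) inserts a new axis of length 1 at position axis of the array's shape; all of the existing data sits at index 0 along that new axis.
For example, a 1-D array with 2 elements has shape (2,): axis=0 gives shape (1, 2), and axis=1 gives shape (2, 1).
Likewise, for an array of shape (2, 3): axis=0 gives (1, 2, 3), axis=1 gives (2, 1, 3), and axis=2 gives (2, 3, 1).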
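The shape rule above can be checked mechanically. The sketch below uses a hypothetical helper, expected_shape (not part of NumPy), which splices a 1 into the shape tuple at position axis, and compares it with np.expand_dims on the shapes from the examples. Note that axis counts positions in the output shape, which has one more dimension than the input, so negative values such as -1 mean "at the end".

```python
import numpy as np


def expected_shape(shape, axis):
    """Hypothetical helper: insert a length-1 axis at position `axis`.

    `axis` is counted in the *output* shape, which has len(shape) + 1 dims,
    so valid values are -(len(shape) + 1) .. len(shape).
    """
    ndim_out = len(shape) + 1
    if not -ndim_out <= axis < ndim_out:
        raise ValueError(f"axis {axis} out of range for {ndim_out}-D output")
    axis %= ndim_out  # normalize negative axes, e.g. -1 -> last position
    return tuple(shape[:axis]) + (1,) + tuple(shape[axis:])


# The cases from the text: a (2,) array and a (2, 3) array.
for shape, axis in [((2,), 0), ((2,), 1), ((2, 3), 0), ((2, 3), 1), ((2, 3), 2)]:
    a = np.zeros(shape)
    got = np.expand_dims(a, axis).shape
    assert got == expected_shape(shape, axis)
    print(shape, "axis =", axis, "->", got)
```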
help(np.expand_dims)
Help on function expand_dims in module numpy.lib.shape_base:
expand_dims(a, axis)
Expand the shape of an array.
Insert a new axis, corresponding to a given position in the array shape.
Parameters
----------
a : array_like
    Input array.
axis : int
    Position (amongst axes) where new axis is to be inserted.
Returns
-------
res : ndarray
    Output array. The number of dimensions is one greater than that of
    the input array.
See Also
--------
doc.indexing, atleast_1d, atleast_2d, atleast_3d
Examples
--------
>>> x = np.array([1,2])
>>> x.shape
(2,)
The following is equivalent to ``x[np.newaxis,:]`` or ``x[np.newaxis]``:
>>> y = np.expand_dims(x, axis=0)
>>> y
array([[1, 2]])
>>> y.shape
(1, 2)
>>> y = np.expand_dims(x, axis=1) # Equivalent to x[:,newaxis]
>>> y
array([[1],
       [2]])
>>> y.shape
(2, 1)
Note that some examples may use ``None`` instead of ``np.newaxis``. These
are the same objects:
>>> np.newaxis is None
True
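Because np.newaxis is None, the expand_dims calls in the docstring can be written as plain indexing. A short check using the same x = [1, 2] as the docstring:

```python
import numpy as np

x = np.array([1, 2])

# axis=0: new leading axis, same as x[np.newaxis, :] or x[None]
a = np.expand_dims(x, axis=0)
b = x[np.newaxis, :]
c = x[None]
print(a.shape, b.shape, c.shape)  # (1, 2) (1, 2) (1, 2)

# axis=1: new trailing axis, same as x[:, np.newaxis] or x[:, None]
d = np.expand_dims(x, axis=1)
e = x[:, None]
print(d.shape, e.shape)  # (2, 1) (2, 1)

print(np.array_equal(a, c), np.array_equal(d, e))  # True True
```

Indexing with None is handy when several axes must be added at once (for example x[None, :, None]), while expand_dims makes the target position explicit as a number.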
import numpy as np

x = np.array([1, 2, 3])
print(x)
print(x.shape)
[1 2 3]
(3,)
y = np.expand_dims(x, axis=0)
print(y)
print("y.shape:", y.shape)
print("y[0][1]:", y[0][1])
[[1 2 3]]
y.shape: (1, 3)
y[0][1]: 2
y = np.expand_dims(x, axis=1)
print(y)
print("y.shape:", y.shape)
print("y[1][0]:", y[1][0])
[[1]
 [2]
 [3]]
y.shape: (3, 1)
y[1][0]: 2
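# axis=-1 appends the new axis at the end. Older NumPy also accepted an
# out-of-range value such as axis=3 here and silently did the same; current
# NumPy raises AxisError for it.
y = np.expand_dims(x, axis=-1)
print(y)
print("y.shape:", y.shape)
print("y[2][0]:", y[2][0])
[[1]
 [2]
 [3]]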
y.shape: (3, 1)
y[2][0]: 3
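For a 1-D input the result is 2-D, so only axis values from -2 to 1 are meaningful. A value such as axis=3 is out of range: older NumPy releases accepted it and appended the new axis at the end, while current NumPy raises AxisError. The sketch below lists the valid axis values and shows the error being caught; it catches ValueError because AxisError subclasses it, which avoids depending on the module path AxisError has in a particular NumPy version.

```python
import numpy as np

x = np.array([1, 2, 3])

# For a 1-D input the output is 2-D, so the valid axes are -2, -1, 0, 1.
for axis in (-2, -1, 0, 1):
    print(axis, np.expand_dims(x, axis).shape)

# axis=3 is out of range. Current NumPy raises AxisError, which subclasses
# ValueError (and IndexError), so catching ValueError is enough here.
try:
    np.expand_dims(x, axis=3)
except ValueError as exc:
    print("error:", exc)
```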
x = np.array([[1, 2, 3], [4, 5, 6]])
print(x)
print(x.shape)
[[1 2 3]
 [4 5 6]]
(2, 3)
y = np.expand_dims(x, axis=0)
print(y)
print("y.shape:", y.shape)
print("y[0][1]:", y[0][1])
[[[1 2 3]
  [4 5 6]]]
y.shape: (1, 2, 3)
y[0][1]: [4 5 6]
y = np.expand_dims(x, axis=1)
print(y)
print("y.shape:", y.shape)
print("y[1][0]:", y[1][0])
[[[1 2 3]]

 [[4 5 6]]]
y.shape: (2, 1, 3)
y[1][0]: [4 5 6]
# axis=-1 (here equivalent to axis=2) appends the new axis at the end. The
# out-of-range axis=3 was accepted by older NumPy but raises AxisError now.
y = np.expand_dims(x, axis=-1)
print(y)
print("y.shape:", y.shape)
# y.shape is (2, 3, 1), so axis 0 has size 2 and y[2] is out of range.
print("y[2][0]:", y[2][0])
[[[1]
  [2]
  [3]]

 [[4]
  [5]
  [6]]]
y.shape: (2, 3, 1)
Traceback (most recent call last):
  ...
IndexError: index 2 is out of bounds for axis 0 with size 2
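The IndexError comes from reading the new shape in the wrong order: after expanding at the last axis, y has shape (2, 3, 1), so its first index only takes the values 0 and 1. The sketch below checks the general rule behind all three 2-D examples: each element x[i, j] is found in y by inserting an extra index 0 at position axis.

```python
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)

for axis in (0, 1, 2):
    y = np.expand_dims(x, axis)
    # Every element x[i, j] reappears in y with an extra index 0 at `axis`.
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            idx = [i, j]
            idx.insert(axis, 0)
            assert y[tuple(idx)] == x[i, j]
    print(axis, y.shape)

# With axis=-1, y.shape is (2, 3, 1): the first index runs over 0..1, so
# y[2] is out of range, while y[1][2] is the length-1 array [6].
y = np.expand_dims(x, -1)
print(y[1][2], y[1, 2, 0])
```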