numpy的维度增删函数——np.expand_dims()和np.squeeze()

目录

  • np.expand_dims()
  • np.squeeze()

numpy是python中重要的科学计算库,常用于数组或矩阵的计算,此时便会涉及到数组维度匹配问题,虽然numpy有broadcast机制,但为避免一些难以察觉的bug,有必要对数组的维度进行增删操作,以使数组维度相匹配。

numpy为用户提供了维度扩展函数np.expand_dims()和维度删减函数np.squeeze()。

np.expand_dims()

np.expand_dims(a, axis)

参数如下:

  • a : array_like
  • axis : int
    Position in the expanded axes where the new axis is placed.

该函数的作用是在指定轴axis上增加数组a的一个维度,即,在第“axis”维,加一个维度出来,原先在“axis”左边的维度保持位置不变,在“axis”右边的维度整体右移。

注意:该函数不改变输入数组a,而是产生一个新数组,新数组中的元素与原数组完全相同。

假设三维数组a的shape是(m, n, c),则

  • np.expand_dims(a, axis=0)表示在a的第一个维度上增加一个新的维度,而其他维度整体往右移,最终得到shape为(1, m, n, c)的新数组,新数组中的元素与原数组完全相同。
import numpy as np

a = np.reshape(list(range(24)), (2, 3, 4))
a_new = np.expand_dims(a, axis=0)
print('a =', a)
print('a_new =', a_new)
print('a.shape = ', a.shape)
print('a_new.shape = ', a_new.shape)

输出结果:

a = [[[ 0  1  2  3]
      [ 4  5  6  7]
      [ 8  9 10 11]]

     [[12 13 14 15]
      [16 17 18 19]
      [20 21 22 23]]]
      
a_new = [[[[ 0  1  2  3]
           [ 4  5  6  7]
           [ 8  9 10 11]]

          [[12 13 14 15]
           [16 17 18 19]
           [20 21 22 23]]]]
a.shape =  (2, 3, 4)
a_new.shape =  (1, 2, 3, 4)
  • np.expand_dims(a, axis=1)将得到shape为(m, 1, n, c)的新数组,新数组中的元素与原数组a完全相同。
  • np.expand_dims(a, axis=2)将得到shape为(m, n, 1, c)的新数组,新数组中的元素与原数组a完全相同。
  • np.expand_dims(a, axis=3)将得到shape为(m, n, c, 1)的新数组,新数组中的元素与原数组a完全相同。
import numpy as np

a = np.reshape(list(range(24)), (2, 3, 4))
print('a =', a)
print('np.expand_dims(a, axis=1) =', np.expand_dims(a, axis=1))
print('np.expand_dims(a, axis=2) =', np.expand_dims(a, axis=2))
print('np.expand_dims(a, axis=3) =', np.expand_dims(a, axis=3))
print('a.shape = ', a.shape)
print('np.expand_dims(a, axis=1).shape =', np.expand_dims(a, axis=1).shape)
print('np.expand_dims(a, axis=2).shape =', np.expand_dims(a, axis=2).shape)
print('np.expand_dims(a, axis=3).shape =', np.expand_dims(a, axis=3).shape)

输出结果:

a = [[[ 0  1  2  3]
      [ 4  5  6  7]
      [ 8  9 10 11]]
      
     [[12 13 14 15]
      [16 17 18 19]
      [20 21 22 23]]]
np.expand_dims(a, axis=1) = [[[[ 0  1  2  3]
                               [ 4  5  6  7]
                               [ 8  9 10 11]]]
                               
           				     [[[12 13 14 15]
  							   [16 17 18 19]
                               [20 21 22 23]]]]
np.expand_dims(a, axis=2) = [[[[ 0  1  2  3]]
 							  [[ 4  5  6  7]]
                              [[ 8  9 10 11]]]

                             [[[12 13 14 15]]
                              [[16 17 18 19]]
                              [[20 21 22 23]]]]
np.expand_dims(a, axis=3) = [[[[ 0]
                               [ 1]
                               [ 2]
   							   [ 3]]

 						      [[ 4]
  						       [ 5]
  							   [ 6]
 							   [ 7]]

							  [[ 8]
  							   [ 9]
  							   [10]
  							   [11]]]


							 [[[12]
							   [13]
							   [14]
							   [15]]
							
							  [[16]
							   [17]
							   [18]
							   [19]]
							
							  [[20]
							   [21]
							   [22]
							   [23]]]]
a.shape =  (2, 3, 4)
np.expand_dims(a, axis=1).shape = (2, 1, 3, 4)
np.expand_dims(a, axis=2).shape = (2, 3, 1, 4)
np.expand_dims(a, axis=3).shape = (2, 3, 4, 1)

np.squeeze()

squeeze(a, axis=None)

参数:

  • a : array_like
    Input data.
  • axis : None or int or tuple of ints, optional

该函数的作用是:删除输入数组a中维度为1的维度,并返回新的数组,新数组的元素与原数组a完全相同。(Remove single-dimensional entries from the shape of an array.)

>>> a = np.array([[[0], [1], [2]]])
>>> a.shape
(1, 3, 1)

# 未指定axis,则删除所有维度为1的维度
>>> np.squeeze(a)
[0, 1, 2]
>>> np.squeeze(a).shape
(3,)

# 指定axis=0,则删除该维度
>>> np.squeeze(a, axis=0)
[[0]
 [1]
 [2]]
>>> np.squeeze(a, axis=0).shape
(3, 1)

# 指定axis=2,则删除该维度
>>> np.squeeze(a, axis=2)
[[0 1 2]]
>>> np.squeeze(a, axis=2).shape
(3, 1)

# 同时指定axis=0和axis=2,则删除这两个维度
>>> np.squeeze(a, axis=(0, 2))
[0 1 2]
>>> np.squeeze(a, axis=(0, 2)).shape
(3,)

# 对于指定的axis,其维度必定为1,否则会报错
>>> np.squeeze(a, axis=1).shape
Traceback (most recent call last):
...
ValueError: cannot select an axis to squeeze out which has size not equal to one

你可能感兴趣的:(numpy的维度增删函数——np.expand_dims()和np.squeeze())