numpy中的squeeze()函数

numpy.squeeze(a, axis=None)

squeeze()函数的功能是:从矩阵shape中,去掉维度为1的。例如一个矩阵是的shape是(5, 1),使用过这个函数后,结果为(5,)。

参数
a是输入的矩阵
axis : 选择shape中的一维条目的子集。如果在shape大于1的情况下设置axis,则会引发错误。

栗子
要使用numpy先导入numpy库
import numpy as np


>>> x = np.array([[[0], [1], [2]]])
>>> x.shape
(1, 3, 1)
>>> np.squeeze(x).shape
(3,)
>>> np.squeeze(x, axis=(2,)).shape
(1, 3)

squeeze()的源码

def squeeze(a, axis=None):
    """
    Remove single-dimensional entries from the shape of an array.
    Parameters
    ----------
    a : array_like
        Input data.
    axis : None or int or tuple of ints, optional
        .. versionadded:: 1.7.0
        Selects a subset of the single-dimensional entries in the
        shape. If an axis is selected with shape entry greater than
        one, an error is raised.
    Returns
    -------
    squeezed : ndarray
        The input array, but with all or a subset of the
        dimensions of length 1 removed. This is always `a` itself
        or a view into `a`.
    Examples
    --------
    >>> x = np.array([[[0], [1], [2]]])
    >>> x.shape
    (1, 3, 1)
    >>> np.squeeze(x).shape
    (3,)
    >>> np.squeeze(x, axis=(2,)).shape
    (1, 3)
    """
    try:
        squeeze = a.squeeze
    except AttributeError:
        return _wrapit(a, 'squeeze')
    try:
        # First try to use the new axis= parameter
        return squeeze(axis=axis)
    except TypeError:
        # For backwards compatibility
return squeeze()

你可能感兴趣的:(numpy)