Using flatten() in PyTorch and NumPy

flatten usage

  • 1. ndarray.flatten()
    • 1.1 ndarray.flatten()
    • 1.2 ndarray.flatten() and ndarray.ravel()
    • 1.3 ndarray.squeeze()
  • 2. torch.flatten()
    • 2.1 torch.flatten()
    • 2.2 torch.squeeze() and torch.unsqueeze()

1. ndarray.flatten()

1.1 ndarray.flatten()

  • Applies to NumPy arrays
    flatten is a method of numpy.ndarray. It returns a copy of the array collapsed into one dimension.

Parameters:
ndarray.flatten(order='C'): return a copy of the array collapsed into one dimension. order : {'C', 'F', 'A', 'K'}, optional:

  • 'C' means to flatten in row-major (C-style) order.
  • 'F' means to flatten in column-major (Fortran-style) order.
  • 'A' means to flatten in column-major order if a is Fortran contiguous in memory, row-major order otherwise.
  • 'K' means to flatten a in the order the elements occur in memory.
    The default is 'C'.
  • Example:
    The argument is a str: 'C' flattens row by row, 'F' flattens column by column.
>>> import numpy as np
>>> a = np.array([[1,2,3],[4,5,6]])
>>> a
array([[1, 2, 3],
       [4, 5, 6]])
>>> a.flatten('C')
array([1, 2, 3, 4, 5, 6])
>>> a.flatten('F')
array([1, 4, 2, 5, 3, 6])
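The 'A' and 'K' values depend on memory layout rather than on the logical shape, so they give the same result as 'C' on the array above. A small sketch that makes the difference visible, assuming the same 2x3 array and using its transpose (a Fortran-contiguous view of the same buffer); the variable names are illustrative only:

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6]])
t = a.T  # view of the same buffer; Fortran-contiguous, shape (3, 2)

c_order = t.flatten('C')  # logical row-major order of t
a_order = t.flatten('A')  # t is Fortran-contiguous, so column-major is used
k_order = t.flatten('K')  # order in which the elements sit in memory

print(c_order)  # [1 4 2 5 3 6]
print(a_order)  # [1 2 3 4 5 6]
print(k_order)  # [1 2 3 4 5 6]
```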
  • Flattening a list:
    A plain Python list has no flatten method; use a list comprehension instead.
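A minimal sketch of the list-comprehension approach, assuming a list nested exactly one level deep (deeper nesting would need recursion, which is not shown here):

```python
nested = [[1, 2, 3], [4, 5, 6]]

# The outer loop walks the rows, the inner loop walks the items of each row.
flat = [item for row in nested for item in row]

print(flat)  # [1, 2, 3, 4, 5, 6]
```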

See also: "Using flatten() and matrix.A in Python"


1.2 ndarray.flatten() and ndarray.ravel()

ndarray.flatten() and ndarray.ravel() both flatten an array into one dimension. The difference is:

  • ndarray.flatten(): always returns a copy of the original array
  • ndarray.ravel(): returns a view of the original array whenever possible and copies only when it has to (for example, when the array is not contiguous in the requested order)
>>> a.flatten()[0] = -1
>>> a
array([[1, 2, 3],
       [4, 5, 6]])
>>> a.ravel()[0] = -1
>>> a
array([[-1,  2,  3],
       [ 4,  5,  6]])
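Whether a result is a view or a copy can also be checked directly instead of by writing into it. A short sketch using np.shares_memory, with illustrative variable names; the transposed case is included because ravel() has to copy there:

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6]])

# flatten() always allocates a new buffer.
copy_shares = np.shares_memory(a, a.flatten())

# ravel() on a C-contiguous array returns a view of the same buffer.
view_shares = np.shares_memory(a, a.ravel())

# a.T is not C-contiguous, so ravel() in the default 'C' order must copy.
transposed_shares = np.shares_memory(a, a.T.ravel())

print(copy_shares, view_shares, transposed_shares)  # False True False
```

Because ravel() may silently return a copy, code that relies on writing through the result is safer when it checks for shared memory first.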

1.3 ndarray.squeeze()

ndarray.squeeze(axis=None): removes axes of length 1. The optional axis argument restricts the removal to the given axis or axes.

>>> b = np.array([[[1],[2]]])
>>> b
array([[[1],
        [2]]])
>>> b.shape
(1, 2, 1)
>>> b.squeeze()
array([1, 2])
>>> b.squeeze(0)
array([[1],
       [2]])
>>> b.squeeze(2)
array([[1, 2]])
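Two details of the axis argument are worth a quick sketch: it also accepts a tuple of axes, and NumPy raises a ValueError when a selected axis does not have length 1. The array below is an assumed stand-in with the same shape as b above:

```python
import numpy as np

b = np.zeros((1, 2, 1))

# A tuple removes several length-1 axes at once.
both = b.squeeze(axis=(0, 2))
print(both.shape)  # (2,)

# Axis 1 has length 2, so it cannot be squeezed.
try:
    b.squeeze(axis=1)
    raised = False
except ValueError:
    raised = True
print(raised)  # True
```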

2. torch.flatten()

2.1 torch.flatten()

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

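start_dim and end_dim select the first and last dimensions to flatten; only the dimensions in that range are merged.
For example, flattening from dim = 1 turns shape (2, 2, 2) into (2, 4).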

>>> t = torch.tensor([[[1, 2],
...                    [3, 4]],
...                   [[5, 6],
...                    [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])
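The start_dim/end_dim behaviour amounts to a reshape that merges a contiguous range of axes. The sketch below reproduces it in NumPy to make the shape arithmetic explicit; flatten_dims is a hypothetical helper written for this illustration, not part of NumPy or PyTorch, and it assumes an input with at least one dimension:

```python
import numpy as np

def flatten_dims(x, start_dim=0, end_dim=-1):
    """Hypothetical helper: merge axes start_dim..end_dim (inclusive) of x."""
    start = start_dim % x.ndim  # normalise negative indices
    end = end_dim % x.ndim
    # Keep the axes before and after the range, collapse the range into one.
    new_shape = x.shape[:start] + (-1,) + x.shape[end + 1:]
    return x.reshape(new_shape)

t = np.arange(1, 9).reshape(2, 2, 2)

print(flatten_dims(t).shape)                        # (8,)
print(flatten_dims(t, start_dim=1).shape)           # (2, 4)
print(flatten_dims(t, start_dim=0, end_dim=1).shape)  # (4, 2)
```

The last call shows the end_dim case that the session above does not cover: merging only the first two dimensions leaves the final one intact.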

Reference: torch.flatten, PyTorch documentation


2.2 torch.squeeze() and torch.unsqueeze()

torch also provides a squeeze function that works like the NumPy one. One difference: if the given dim does not have size 1, torch returns the tensor unchanged, whereas NumPy raises a ValueError.
torch.squeeze(input, dim=None) → Tensor

>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])

In addition, torch provides torch.unsqueeze, which inserts a dimension of size 1 at the given position.
torch.unsqueeze(input, dim) → Tensor

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])
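For comparison, the NumPy counterparts of unsqueeze are np.expand_dims and indexing with np.newaxis. A brief sketch with the same values as above; the variable names are illustrative:

```python
import numpy as np

x = np.array([1, 2, 3, 4])

row = np.expand_dims(x, axis=0)  # new leading axis
col = x[:, np.newaxis]           # new trailing axis via indexing

print(row.shape)  # (1, 4)
print(col.shape)  # (4, 1)

# squeeze undoes the insertion.
restored = col.squeeze(axis=1)
print(restored.shape)  # (4,)
```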
