Python中的np.split与MXNet中的nd.split的一些用法区别

我们先来看下np.split的实现方法:

@array_function_dispatch(_split_dispatcher)
def split(ary, indices_or_sections, axis=0):
    try:
        len(indices_or_sections)
    except TypeError:
        sections = indices_or_sections
        N = ary.shape[axis]
        if N % sections:
            raise ValueError(
                'array split does not result in an equal division') from None
    return array_split(ary, indices_or_sections, axis)

当然有兴趣的可以继续看array_split的具体操作方法。

从split的定义可以看到,参数是(数组,数或数组,维度)返回值是列表,里面的每组元素是数组。看示例:
参数是整数

import numpy as np

a=np.arange(10)
np.split(a,5)
#[array([0, 1]), array([2, 3]), array([4, 5]), array([6, 7]), array([8, 9])]

将0-9,均分成5份。如果不能被整除,比如是4,将出现错误:
ValueError: array split does not result in an equal division

 参数是列表,按照里面每个值来分段

a=np.arange(10)
np.split(a,[4,8])
#[array([0, 1, 2, 3]), array([4, 5, 6, 7]), array([8, 9])]

变形成二维数组来切分

a=np.arange(40).reshape(8,5)
np.split(a,4)
'''
[array([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]]),
 array([[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]]),
 array([[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]]),
 array([[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]])]
'''

axis=1的结果

a=np.arange(40).reshape(5,8)
np.split(a,4,axis=1)
'''
[array([[ 0,  1],
        [ 8,  9],
        [16, 17],
        [24, 25],
        [32, 33]]),
 array([[ 2,  3],
        [10, 11],
        [18, 19],
        [26, 27],
        [34, 35]]),
 array([[ 4,  5],
        [12, 13],
        [20, 21],
        [28, 29],
        [36, 37]]),
 array([[ 6,  7],
        [14, 15],
        [22, 23],
        [30, 31],
        [38, 39]])]
'''

接下来的nd.split的用法跟np.split虽然用法很像,还是存在一些区别需要注意。

还是贴下nd.split的方法:

def split(data=None, num_outputs=_Null, axis=_Null, squeeze_axis=_Null, out=None, name=None, **kwargs):
   return (0,)

跟np.split的区别就是必须指定axis,不然会报错。

from mxnet import nd

a=nd.arange(40).reshape(8,5)
nd.split(a,4,axis=0)

'''
[
 [[0. 1. 2. 3. 4.]
  [5. 6. 7. 8. 9.]]
 ,
 
 [[10. 11. 12. 13. 14.]
  [15. 16. 17. 18. 19.]]
 ,
 
 [[20. 21. 22. 23. 24.]
  [25. 26. 27. 28. 29.]]
 ,
 
 [[30. 31. 32. 33. 34.]
  [35. 36. 37. 38. 39.]]
 ]
'''

三维的例子亦如是,如:切分第二维切4份

a=nd.arange(40).reshape(2,4,5)
nd.split(a,4,axis=1)
'''
[
 [[[ 0.  1.  2.  3.  4.]]
 
  [[20. 21. 22. 23. 24.]]]
 ,
 
 [[[ 5.  6.  7.  8.  9.]]
 
  [[25. 26. 27. 28. 29.]]]
 ,
 
 [[[10. 11. 12. 13. 14.]]
 
  [[30. 31. 32. 33. 34.]]]
 ,
 
 [[[15. 16. 17. 18. 19.]]
 
  [[35. 36. 37. 38. 39.]]]
 ]
'''

每个元素的形状是nd.split(a,4,axis=1)[1].shape #(2,1,5)

除了参数名称不一样,个数也不一样,比如squeeze_axis这个新增的参数,可以减掉一维。

a=nd.arange(40).reshape(2,4,5)
nd.split(a,2,axis=0,squeeze_axis=1)
'''
[
 [[ 0.  1.  2.  3.  4.]
  [ 5.  6.  7.  8.  9.]
  [10. 11. 12. 13. 14.]
  [15. 16. 17. 18. 19.]]
 ,
 
 [[20. 21. 22. 23. 24.]
  [25. 26. 27. 28. 29.]
  [30. 31. 32. 33. 34.]
  [35. 36. 37. 38. 39.]]
 ]
'''

看出有什么不同了吗?少了一维,本来里面每个元素是 1x4x5 @cpu(0)>,现在变为
再看一例:

a=nd.arange(40).reshape(2,4,5)
nd.split(a,4,axis=1,squeeze_axis=1)
'''
[
 [[ 0.  1.  2.  3.  4.]
  [20. 21. 22. 23. 24.]]
 ,
 
 [[ 5.  6.  7.  8.  9.]
  [25. 26. 27. 28. 29.]]
 ,
 
 [[10. 11. 12. 13. 14.]
  [30. 31. 32. 33. 34.]]
 ,
 
 [[15. 16. 17. 18. 19.]
  [35. 36. 37. 38. 39.]]
 ]
'''

如果没有squeeze_axis=1这个参数,里面的元素形状是1x5 @cpu(0)>, 现在变为
所以这个其实就是将所在切分的维,有且仅有1,那么就减掉这个维度。这个其实是有意义的,毕竟属于没数据的占着空的维度,可以去掉。

你可能感兴趣的:(Python,深度学习框架(MXNet),np.split,nd.split,squeeze_axis=1)