我们先来看下np.split的实现方法:
@array_function_dispatch(_split_dispatcher)
def split(ary, indices_or_sections, axis=0):
try:
len(indices_or_sections)
except TypeError:
sections = indices_or_sections
N = ary.shape[axis]
if N % sections:
raise ValueError(
'array split does not result in an equal division') from None
return array_split(ary, indices_or_sections, axis)
当然有兴趣的可以继续看array_split的具体操作方法。
从split的定义可以看到,参数是(数组,数或数组,维度)返回值是列表,里面的每组元素是数组。看示例:
参数是整数
import numpy as np
a=np.arange(10)
np.split(a,5)
#[array([0, 1]), array([2, 3]), array([4, 5]), array([6, 7]), array([8, 9])]
将0-9,均分成5份。如果不能被整除,比如是4,将出现错误:
ValueError: array split does not result in an equal division
参数是列表,按照里面每个值来分段
a=np.arange(10)
np.split(a,[4,8])
#[array([0, 1, 2, 3]), array([4, 5, 6, 7]), array([8, 9])]
变形成二维数组来切分
a=np.arange(40).reshape(8,5)
np.split(a,4)
'''
[array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]]),
array([[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]]),
array([[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29]]),
array([[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39]])]
'''
axis=1的结果
a=np.arange(40).reshape(5,8)
np.split(a,4,axis=1)
'''
[array([[ 0, 1],
[ 8, 9],
[16, 17],
[24, 25],
[32, 33]]),
array([[ 2, 3],
[10, 11],
[18, 19],
[26, 27],
[34, 35]]),
array([[ 4, 5],
[12, 13],
[20, 21],
[28, 29],
[36, 37]]),
array([[ 6, 7],
[14, 15],
[22, 23],
[30, 31],
[38, 39]])]
'''
接下来的nd.split的用法跟np.split虽然用法很像,还是存在一些区别需要注意。
还是贴下nd.split的方法:
def split(data=None, num_outputs=_Null, axis=_Null, squeeze_axis=_Null, out=None, name=None, **kwargs):
return (0,)
跟np.split的区别就是必须指定axis,不然会报错。
from mxnet import nd
a=nd.arange(40).reshape(8,5)
nd.split(a,4,axis=0)
'''
[
[[0. 1. 2. 3. 4.]
[5. 6. 7. 8. 9.]]
,
[[10. 11. 12. 13. 14.]
[15. 16. 17. 18. 19.]]
,
[[20. 21. 22. 23. 24.]
[25. 26. 27. 28. 29.]]
,
[[30. 31. 32. 33. 34.]
[35. 36. 37. 38. 39.]]
]
'''
三维的例子亦如是,如:切分第二维切4份
a=nd.arange(40).reshape(2,4,5)
nd.split(a,4,axis=1)
'''
[
[[[ 0. 1. 2. 3. 4.]]
[[20. 21. 22. 23. 24.]]]
,
[[[ 5. 6. 7. 8. 9.]]
[[25. 26. 27. 28. 29.]]]
,
[[[10. 11. 12. 13. 14.]]
[[30. 31. 32. 33. 34.]]]
,
[[[15. 16. 17. 18. 19.]]
[[35. 36. 37. 38. 39.]]]
]
'''
每个元素的形状是nd.split(a,4,axis=1)[1].shape #(2,1,5)
除了参数名称不一样,个数也不一样,比如squeeze_axis这个新增的参数,可以减掉一维。
a=nd.arange(40).reshape(2,4,5)
nd.split(a,2,axis=0,squeeze_axis=1)
'''
[
[[ 0. 1. 2. 3. 4.]
[ 5. 6. 7. 8. 9.]
[10. 11. 12. 13. 14.]
[15. 16. 17. 18. 19.]]
,
[[20. 21. 22. 23. 24.]
[25. 26. 27. 28. 29.]
[30. 31. 32. 33. 34.]
[35. 36. 37. 38. 39.]]
]
'''
看出有什么不同了吗?少了一维,本来里面每个元素是
再看一例:
a=nd.arange(40).reshape(2,4,5)
nd.split(a,4,axis=1,squeeze_axis=1)
'''
[
[[ 0. 1. 2. 3. 4.]
[20. 21. 22. 23. 24.]]
,
[[ 5. 6. 7. 8. 9.]
[25. 26. 27. 28. 29.]]
,
[[10. 11. 12. 13. 14.]
[30. 31. 32. 33. 34.]]
,
[[15. 16. 17. 18. 19.]
[35. 36. 37. 38. 39.]]
]
'''
如果没有squeeze_axis=1这个参数,里面的元素形状是
所以这个其实就是将所在切分的维,有且仅有1,那么就减掉这个维度。这个其实是有意义的,毕竟属于没数据的占着空的维度,可以去掉。