Numpy的Universal functions 中要求输入的数组shape是一致的,当数组的shape不想等的时候,则会使用广播机制,调整数组使得shape一样,满足规则,则可以运算,否则就出错
四条规则如下:
中文
以下通过实例来说明这些问题
一般情况下,numpy 都是采用一一对应的方式(element-by-element )进行计算
例子1:
>>> from numpy import array
>>> a = array([1.0,2.0,3.0])
>>> b = array([2.0,2.0,2.0])
>>> a * b
array([ 2., 4., 6.])
当不相等时,则会采用规则对其:
>>> from numpy import array
>>> a = array([1.0,2.0,3.0])
>>> b = 2.0
>>> a * b
array([ 2., 4., 6.])
a.shape得到的是(3,) b是一个浮点数,如果转换成array,则b.shape是一个(),a的1轴对齐,补齐为1,a.shape(3,1),b对齐,则对齐也为(3,1),然后按照一一对应的方式计算
或许上述例子不是太明确,下面采用一个更加确切的例子说明:
>>> import numpy as np
>>> a = np.arange(0, 6).reshape(6, 1)
>>> a
array([[ 0], [1], [2], [3], [4], [5]])
>>> a.shape
(6, 1)
>>> b = np.arange(0, 5)
>>> b.shape
(5,)
>>> c = a + b
>>> print c
[[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]
[5 6 7 8 9]]
在上述实例中,当使用+运算时,由于shape不一致,按照规则1,会对b进行
reshape,b.reshape成(1,5),可能会问为什么不是(5,1),因为这个就不能计算了,那么如果b的shape是(6,)的时候呢,都可以运算;所以对于对齐本身我自己是没有理解太过透彻,所以,我找了一下官方文档,其中的一个图是这样的:
我理解是,其本身的形状是不能改变的,只能在原来的基础上延伸,像上述的例子中,如果b的shape是(6,),如果在broadcasting的时候reshape(6,1)则已经是属于改变了原来的数组的形状,进行了翻转,而不是延伸。
接着上述实例,对于b则reshape成了(1,5),a则保持(6,1),按照规则2,则输出为每个轴上的最大值,则c.shape为(6,5);
对于规则3和规则4,都是在描述延伸的条件和方式,所以对于我的理解我也更加确信了,如果有大侠觉得有问题,请帮忙指正
参考:
NumPy-快速处理数据(http://old.sebug.net/paper/books/scipydoc/numpy_intro.html)
Array Broadcasting in numpy(http://scipy.github.io/old-wiki/pages/EricsBroadcastingDoc)
NumPy Reference(http://docs.scipy.org/doc/numpy-1.10.0/reference/ufuncs.html)