广播(broadcasting)描述了不同 shape
的 NumPy arrays 之间该如何进行算术运算。广播虽然是一个非常强大的功能,但非常容易引起混淆。将标量与一个 array 相乘的运算其实就是一个非常简单的广播例子:
arr = np.arange(5)
arr * 4
"""
array([ 0, 4, 8, 12, 16])
"""
我们说标量值 4
被广播到了 array 中的所有元素。
再来看一个一维 array 与二维 array 之间的运算,给二维 array 的每列执行取均值操作:
arr = np.random.randn(4, 3)
arr.mean(0)
"""
array([-0.28253349, 0.53668684, 0.04200281])
"""
注意 mean
中的 0 表示的是沿 axis 0
的方向,即得到的是每列的均值,如果将 0 换成 1,那么我们将沿 axis 1
的方向,即得到的是每行的均值,长度为 4。
demeaned = arr - arr.mean(0)
demeaned
"""
array([[-0.73739998, -0.26250355, 0.3488015 ],
[ 2.96855138, 1.60747399, -0.01027794],
[-0.62175362, -1.27725946, -0.34278256],
[-1.60939777, -0.06771098, 0.004259 ]])
"""
这相当于我们在 axis 0
的方向对一个一维 array 进行了广播,计算过程可按下图来理解:
Broadcasting Rule Two arrays are compatible for broadcasting if for each trailing dimension (i.e., starting from the end) the axis lengths match or if either of the lengths is 1. Broadcasting is then performed over the missing or length 1 dimensions.
根据这个规则,对 arr
的行进行去均值操作可能就没那么容易了。因为 arr.mean(0)
的长度为 3,和 arr
的 trailing dimension 3 是相同的,因此可以直接广播。但 arr.mean(0)
的长度为 4,我们必须先把 shape
从 (4,)
变为 (4, 1)
:
arr - arr.mean(1).reshape((4, 1))
"""
array([[-0.90161818, 0.39249858, 0.5091196 ],
[ 1.06538336, 0.5235263 , -1.58890966],
[-0.25574062, -0.09202612, 0.34776674],
[-1.4330334 , 0.92787373, 0.50515967]])
"""
再看将一个 4×2 array 与一个 3×4×2 array 沿 axis 0
相加的例子,4×2 与 三维 array 对应的 trailing dimension 是相同的,因此可直接广播:
如果较低维度的 array 为 3×2,那么我们需先将其 shape
改为 (3, 1, 2)
。
正如你所见,我们可能常常需要添加一个新的长度为 1 的维度。除了 reshape
方法,我们还可以通过添加 np.newaxis
属性来实现。
arr = np.zeros((3, 3))
arr_3d = arr[:, np.newaxis, :]
arr_3d.shape
"""
(3, 1, 3)
"""
因此,如果我们有一个三维 array(3×4×5)并且想对 axis 2
去均值,可以这样写:
arr = np.random.randn(3, 4, 5)
depth_means = arr.mean(2)
depth_means.shape
"""
(3, 4)
"""
demeaned = arr - depth_means[:, :, np.newaxis]
广播也可用在对 array 进行赋值上,也要遵循上面的广播规则。我们有 array
arr = np.zeros((4, 3))
最简单的情况
arr[:] = 5
也是一种广播。
如果我们有一个一维 array col
,且想把当中的值赋给 arr
中的列:
col = np.array([1.28, -0.42, 0.44, 1.6])
arr[:] = col[:, np.newaxis]
"""
array([[ 1.28, 1.28, 1.28],
[-0.42, -0.42, -0.42],
[ 0.44, 0.44, 0.44],
[ 1.6 , 1.6 , 1.6 ]])
"""
Python for Data Analysis, 2 n d ^{\rm nd} nd edition. Wes McKinney.