推荐使用Numba加速Python科学计算

因为正在用 Python 写 lattice Boltzmann 的原因,我时不时也会研究 Python 科学计算程序的加速方法。现在为止,陆续尝试过了几个方案,包括 numexpr、Cython 等等,也写了一些博文(1、2)。而这篇文章,则是要作一个推荐,建议所有在使用 Python 做科学计算的人尝试一下 Numba。

推荐的原因,首先得从 Python 科学计算说起。

Python 本身由于其动态语言的本质,并不适合于计算量很大的科学计算。然而 Numpy 提供的带类型的数据结构,以及预编译好的基于 C 或 Fortran 的高速计算库,很大程度上解决了这个问题。绝大部分基于 Python 的科学计算程序,都是将其数据储存在 ndarray 里的。按我自己目前的认识,Numpy 至少带来了两个方面的好处,一是存储的数据带有类型,Python 不再需要动态地猜测变量的数据类型;二是提供了一系列高速的 ufunc,可以快速地对大规模的数组进行运算操作。

Numpy 提供的 ndarray 以及 ufunc 在有些情况下是足够应付一些简单的计算逻辑的,比如 Palabos 官网提供的圆柱绕流的代码。按 Palabos 自己的说法,这个代码短得惊人,而且速度也非常快。

但是,总会有一些情况是用简单的数组没法完成的。至少有两种典型的情况。其一是沿时间的迭代,后一步依赖前一步的结果。即使每一个时间步中的计算量都不大,但所有的时间步不能或者很难做成一个简单的数组计算。另一种情况是确实是基于数组的计算,但计算逻辑比较复杂,没法用简单的数组算式来完成,比如 LBM 中的碰撞迁移如果有非规则障碍物就很难用数组简单表达了。

这种时候,最直接的解决办法是,回到写 C 的方法中来,写循环然后一个元素一个元素地算。Numba、Numexpr 属于直接在 Python 程序基础上作修改来实现加速的,Cython 以及直接写 Fortran 或 C 扩展都属于充分利用 Python 脱水语言特性的办法。按照这篇13年的文章的测试,它们的速度没有太大的区别。

如果速度没有太大区别的话,对我们这种应用语言来计算,而不是专门研究高效计算的人来说,简便、可靠成了最重要的考量因素。而这,也正是我推荐 Numba 的最主要的原因。与 Cython 相比,Numba 在编写的时候与 Cython 是非常类似的,甚至比它还要简洁(不等于简单),因为不用声明变量类型。

然而 Numba 有一些非常明显的优势。

首先,Cython 有一些普通人不一定知道的优化技巧,比如关闭 boundscheck 等,Numba 中则不需要考虑这些问题。前面提到的那篇文章的结果来看,Cython 即使高度优化,在速度上也没有优势,甚至可能更慢。

然后,Numba 在语言层面上还完全是 Python,而不是像 Cython 一样是一种杂交的语言。当然两者都基本是按照 C 的逻辑在写,比如大量出现的嵌套循环(这在纯 Python 脚本中是不可想象的)。

如果说上一条还是一个洁癖的话,这一条则是 Numba 明显优势的地方:Numba 对源程序的修改很小,几乎是只需要加一个 @jit 修饰符就可以了。而 Cython 相对来说就麻烦不少,要额外编译,pyx 文件也有一些讲究,比如还需要 cimport numpy 之类。

最后,Numba 可以以 GPU 或多核 CPU 为目标编译代码,实现方式同样十分简单。

简而言之,Numba 提供的方案,不比别的速度慢,但实现起来要方便不少。既然如此,为何不用它呢?我反正已经全面从 Cython 迁移过来了。

你可能感兴趣的:(推荐使用Numba加速Python科学计算)