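[Numba] Accelerating Computation

I. What is Numba:

Numba can just-in-time (JIT) compile NumPy-based Python code to machine code, giving execution speed close to that of native machine code.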

Numba is a just-in-time compiler for Python that works best on code that uses NumPy arrays and functions, and loops. The most common way to use Numba is through its collection of decorators that can be applied to your functions to instruct Numba to compile them. When a call is made to a Numba-decorated function it is compiled to machine code “just-in-time” for execution and all or part of your code can subsequently run at native machine code speed!

  • Pros:
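    • Makes full use of CPU resources; GPU computing is available through CUDA programming
    • Greatly speeds up large-matrix computation: the larger the matrix, the bigger the gain, up to one or two orders of magnitude
    • Loops can be run in parallel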
  • Cons:
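    • Sorting is slightly slower than in NumPy
    • Small-scale computations gain little
    • Strict restrictions on package dependencies and data types (see the Notes section)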

II. Quick start:

Official getting-started tutorial

https://numba.readthedocs.io/en/stable/user/5minguide.html

Installation

# pip
pip install numba
# conda
conda install numba

A simple example:

import numpy as np
from numba import njit
# Suppose we want the feature distances between 10,000 faces and a face library (10,000 entries in total)
"""
NumPy implementation
"""

compare_1w = np.random.rand(10000, 512)
base_10w = np.random.rand(512, 10000)
distance = np.dot(compare_1w, base_10w)
"""
Numba implementation
"""

# Simply wrap the computation step in a function and decorate that function with @njit()
@njit()
def cal_dot(a, b):
    return np.dot(a, b)
distance = cal_dot(compare_1w, base_10w)

"""Result: roughly a 30x speedup
numpy cost: 0:00:45.465711
numba cost: 0:00:01.553018
"""

III. Notes: common errors and how to fix them

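  1. When decorated with njit(), the function does not accept dict or function arguments
    • Workaround: use Numba to produce the raw data first, then post-process the result outside the decorated function, where dict and function arguments can be used freely
  2. When decorated with njit(), the author found that continue, break, try, and except were not supported inside the function. (Numba's documentation lists break and continue as supported and try/except as partially supported, so this may depend on the Numba version or on the code being inside a parallel loop.)
    • Workaround: simplify the control flow and use only for, if, and else
  3. When decorated with njit()
    • Symptom: declaring a variable such as a = [(),[],[]]*10 raises: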

      numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
      No implementation of function Function(<built-in function setitem>) found for signature:
       >>> setitem(list(Tuple())<iv=None>, int64, Tuple(int64, array(int64, 1d, A), list(float32)<iv=None>))
      
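    • Workaround: restructure the data as an array, or split it across several arrays. The arrays must be allocated with a known size in advance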

  4. When decorated with njit(), calling String.format()
    • Symptom:

      numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
      Unknown attribute 'format' of type Literal[str]([*] Calculate img_type:{} similarity, batch {}/10, progress: {}/{})
      
    • Workaround: drop format and use string concatenation instead

  5. When decorated with njit(parallel=True), calling List.append()
    • Symptom: append is not thread-safe

      File "/******/python3.9/site-packages/numba/cpython/listobj.py", line 1129, in list_to_list
          assert fromty.dtype == toty.dtype
      AssertionError
      
    • Workaround: remove append() from the decorated function, write by index instead, and allocate the list/np.ndarray with its final size in advance

  6. When decorated with njit(parallel=True)
    • Symptom:

      /*******/python3.9/site-packages/numba/np/ufunc/parallel.py:365: NumbaWarning: The TBB threading layer requires TBB version 2019.5 or later i.e., TBB_INTERFACE_VERSION >= 11005. Found TBB_INTERFACE_VERSION = 9107. 
      The TBB threading layer is disabled. warnings.warn(problem)
      
    • Workaround: see https://github.com/numba/numba/issues/6350

      conda install tbb
      # or
      pip install --upgrade tbb
      

IV. Appendix: speed test code ⏱️

import numpy as np
from numba import jit, njit
from datetime import datetime
# test1

@jit(nopython=True)
def cal(a, b):
    c = np.dot(b, a)
    d = np.max(c)
    return d  # return the result so the computed value is actually used

@njit(parallel=True)
def gen(n):
    a = np.random.rand(512, 10**n)
    b = np.random.rand(100,512)
    return a, b
def numba_cal(a, b):
    start = datetime.now()
    cal(a, b)
    end = datetime.now()
    return end - start
def np_cal(a, b):
    start = datetime.now()
    np.dot(b, a)
    end = datetime.now()
    return end - start
for i in range(8):
    a, b = gen(i)
    delta = numba_cal(a, b)
    print("scale {}: numba cost: {}".format(10**i, delta))
    delta = np_cal(a, b)
    print("scale {}: np cost: {}".format(10**i, delta))
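
Before running the loop above all the way to i = 7, it may be worth estimating how much memory the generated arrays need. The sketch below is plain arithmetic, assuming 8-byte float64 elements (the dtype returned by np.random.rand); the helper name `array_gib` is hypothetical.

```python
BYTES_PER_FLOAT64 = 8

def array_gib(rows, cols):
    # Size of a dense float64 array in GiB.
    return rows * cols * BYTES_PER_FLOAT64 / 2**30

for n in range(8):
    a_gib = array_gib(512, 10**n)   # shape of `a` returned by gen(n)
    c_gib = array_gib(100, 10**n)   # shape of np.dot(b, a)
    print(f"scale 10**{n}: a = {a_gib:.3f} GiB, dot result = {c_gib:.3f} GiB")
```

At the largest scale the input array alone is roughly 38 GiB, so the upper bound of the range may need lowering on machines with less memory.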
