Numba can compile NumPy code "just in time" into machine code, giving speeds close to those of native machine code.
Numba is a just-in-time compiler for Python that works best on code that uses NumPy arrays and functions, and loops. The most common way to use Numba is through its collection of decorators that can be applied to your functions to instruct Numba to compile them. When a call is made to a Numba-decorated function it is compiled to machine code “just-in-time” for execution and all or part of your code can subsequently run at native machine code speed!
https://numba.readthedocs.io/en/stable/user/5minguide.html
# pip
pip install numba
# conda
conda install numba
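import numpy as np
from numba import njit

# Suppose we want to compute feature distances (here, dot products of 512-d features)
# between 10,000 faces and a face database of 10,000 entries
"""
NumPy implementation
"""
compare_1w = np.random.rand(10000, 512)
base_10w = np.random.rand(512, 10000)
distance = np.dot(compare_1w, base_10w)

"""
Numba implementation
"""
# Simply wrap the computation in a function and decorate it with @njit()
# (Numba's np.dot support requires SciPy to be installed)
@njit()
def cal_dot(a, b):
    return np.dot(a, b)

distance = cal_dot(compare_1w, base_10w)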
"""Result: roughly 30x faster
numpy cost: 0:00:45.465711
numba cost: 0:00:01.553018
"""
njit(): a function cannot take a plain Python dict or a plain Python function as an argument
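njit(): control flow
for, if and else work. Current Numba documentation also lists continue and break as supported in nopython mode. try/except is only partially supported: a bare except or except Exception works, but the exception object cannot be bound to a name.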
njit(): lists with mixed element types
Symptom: declaring a variable such as a = [(),[],[]]*10 fails with:
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
>>> setitem(list(Tuple())<iv=None>, int64, Tuple(int64, array(int64, 1d, A), list(float32)<iv=None>))
Solution: restructure the data as an array, or split it across several arrays. Array sizes must be declared up front.
njit(): str.format() is not supported
Symptom:
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'format' of type Literal[str]([*] Calculate img_type:{} similarity, batch {}/10, progress: {}/{})
Solution: drop format() and build the string with concatenation instead.
njit(parallel=True): list.append()
Symptom: append() is not thread-safe, and the call fails with:
File "/******/python3.9/site-packages/numba/cpython/listobj.py", line 1129, in list_to_list
assert fromty.dtype == toty.dtype
AssertionError
Solution: remove append() from the decorated function. Instead, preallocate the list/np.ndarray at its final size and write to it by index.
njit(parallel=True): TBB threading layer warning
Symptom:
/*******/python3.9/site-packages/numba/np/ufunc/parallel.py:365: NumbaWarning: The TBB threading layer requires TBB version 2019.5 or later i.e., TBB_INTERFACE_VERSION >= 11005. Found TBB_INTERFACE_VERSION = 9107.
The TBB threading layer is disabled. warnings.warn(problem)
Solution (see https://github.com/numba/numba/issues/6350): install or upgrade TBB:
conda install tbb
# or
pip install --upgrade tbb
import numpy as np
from numba import jit, njit
from datetime import datetime

# test1
@jit(nopython=True)  # equivalent to @njit
def cal(a, b):
    c = np.dot(b, a)
    d = np.max(c)
    return d

@njit(parallel=True)
def gen(n):
    a = np.random.rand(512, 10**n)
    b = np.random.rand(100, 512)
    return a, b

def numba_cal(a, b):
    start = datetime.now()
    cal(a, b)
    end = datetime.now()
    return end - start

def np_cal(a, b):
    start = datetime.now()
    np.dot(b, a)
    end = datetime.now()
    return end - start

# Warm up: the first call to a jitted function includes compilation time,
# so compile cal() once before timing it.
a, b = gen(0)
cal(a, b)

# Note: at scale 10**7, `a` alone needs 512 * 10**7 * 8 bytes (about 41 GB);
# lower the range if memory is limited.
for i in range(8):
    a, b = gen(i)
    delta = numba_cal(a, b)
    print("scale {}: numba cost: {}".format(10**i, delta))
    delta = np_cal(a, b)
    print("scale {}: np cost: {}".format(10**i, delta))