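A quick note: numba does not support NumPy's np.int type as a dtype (np.int is merely an alias for Python's built-in int, not a fixed-width NumPy type). You need to specify an integer type with an explicit bit width; both np.int32 and np.int64 work without problems. The test code is as follows.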
Tested on: Ubuntu 18.04, Python 3.6.8, numba 0.46.0, NumPy 1.16.4.
import numba
import numpy as np
import sys

@numba.jit(nopython=True)
def test_zeros(H, W):
    # array = np.zeros((H, W, 3), dtype=np.int) # Will cause jit error.
    array = np.zeros((H, W, 3), dtype=np.int64) # OK.
    array = np.zeros((H, W, 3), dtype=np.int32) # OK.

def main():
    test_zeros(100, 200)
    return 0

if __name__ == "__main__":
    sys.exit(main())
Running the code above with the np.int line uncommented produces the following error:
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function() with argument(s) of type(s): ((int64, int64, Literal[int](3)), dtype=Function())
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function()
[2] During: typing of call at TestNumba.py (8)
File "TestNumba.py", line 8:
def test_zeros(H, W):
array = np.zeros((H, W, 3), dtype=np.int) # Will cause jit error.
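For comparison, here is a minimal sketch of the working variant that actually returns the arrays, so the dtype can be checked after the call. The function names `make_zeros_int64` and `make_zeros_int32` are hypothetical. The fallback that swaps in a do-nothing decorator when numba is not installed is also an assumption, added only so the snippet runs anywhere. In that case it is plain NumPy and says nothing about numba's behavior.

```python
import numpy as np

try:
    import numba
    njit = numba.njit  # nopython-mode JIT, same as numba.jit(nopython=True)
except ImportError:
    # Assumption: fall back to plain Python if numba is unavailable.
    def njit(func):
        return func


@njit
def make_zeros_int64(H, W):
    # Fixed-width integer dtype: accepted in nopython mode.
    return np.zeros((H, W, 3), dtype=np.int64)


@njit
def make_zeros_int32(H, W):
    # 32-bit variant, also accepted.
    return np.zeros((H, W, 3), dtype=np.int32)


a = make_zeros_int64(100, 200)
b = make_zeros_int32(100, 200)
print(a.shape, a.dtype)
print(b.shape, b.dtype)
```

Spelling out the bit width also makes the array's memory layout explicit, instead of depending on whatever the platform's default integer happens to be.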