numba使用踩坑总结

由于没有完全理解numba就直接使用了,所以犯了一些使用时的错误。

  1. 尽量使用numpy的数据类型来写代码,numba对numpy的支持最好;但是并不是所有的numpy函数都被支持,比如我用到的np.clip, np.pad等函数都不支持,通过下面网址查看到底支持哪些numpy函数:http://numba.pydata.org/numba-doc/latest/reference/numpysupported.html。 遇到无法支持的函数时,两个选择,一个是重新手写该函数;另一个则是选择不在jit加速范围内调用该函数。
  2. np.zeros(shape, type)函数的调用中犯了一个错误,平时习惯性地会使用一个list作为shape参数,如[10, 10],平时正常使用numpy的时候也没问题,但是使用numba加速时却遇到了编译问题:
    Compilation is falling back to object mode WITHOUT looplifting enabled because Function "xxx" failed type inference due to: Invalid use of type(CPUDispatcher()) with parameters (int64, int64, int64, array(float64, 4d, C))
    During: resolving callee type: type(CPUDispatcher())
    
    numba的编译告警确实有点不太直观,从告警上很难定位到具体的问题出在哪儿,而且往往一个问题会引发出多处的告警。具体在定位的时候我喜欢用简化排除法,构造很简单的例子来排除问题,比如先把函数体变成空的,然后一点点加上内容,看问题到底出在哪里。
  3. 再来看一个类似的例子:https://github.com/numba/numba/issues/4650
    在numba的官方那里的一个issue,跟我上面的报错很类似。
    from numba import njit
    import numpy as np
    
    @njit
    def _get_most_similar(query_ftrs: np.ndarray, all_images_ftrs: np.ndarray) -> np.ndarray:
        products = np.empty(all_images_ftrs.shape[0], dtype=query_ftrs.dtype)
        for i in range(len(all_images_ftrs)):
            ftrs = all_images_ftrs[i]
            products[i] = np.dot(query_ftrs, ftrs)
    
    query_ftrs = np.zeros((1, 2048), dtype="float32")
    all_images_ftrs = np.zeros((18536, 2048), dtype="float32")
    
    _get_most_similar.py_func(query_ftrs, all_images_ftrs) # numpy is fine
    
    _get_most_similar(query_ftrs, all_images_ftrs) # numba is not
    
    报错信息如下:
    The error:
    numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
    Invalid use of Function() with argument(s) of type(s): (array(float64, 1d, C), int64, array(float32, 1d, C))
    
    官方人员给出了如下的定位过程:
    numba使用踩坑总结_第1张图片
    总结下就是numba不能混用一个元素的list和单个scaler,但是numpy是可以的。这跟我那个问题的原因也很像,充分说明了numba对数据类型的要求很严格,推理地很严谨,不具有Numpy的兼容性。应该是tuple就别用list,应该是scaler也别用tuple(1,)。
  4. 再看一个类型问题:http://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile
    是个官网文档上的例子。

你可能感兴趣的:(python,numpy,python,numba)