使用 @overload 的指南

高级扩展 API 中所述,您可以使用 @overload 装饰器来创建可在 nopython 模式函数中使用的函数的 Numba 实现。一个常见的用例是重新实现 NumPy 函数,以便它们可以在由 @jit 装饰的代码中调用。本节讨论了何时以及如何使用 @overload 装饰器,以及向 Numba 代码库贡献此类函数可能涉及的内容。这应能帮助您在需要使用 @overload 装饰器或尝试向 Numba 本身贡献新函数时入门。

@overload 装饰器及其变体在您拥有无法控制的第三方库,并希望为该库中的特定函数提供 Numba 兼容的实现时非常有用。

具体示例

假设您正在开发一种最小化算法,该算法利用 scipy.linalg.norm 来查找不同的向量范数和矩阵的 Frobenius 范数。您知道只会涉及整数和实数。(尽管这听起来像一个人工示例,尤其是因为 numpy.linalg.norm 的 Numba 实现已经存在,但它主要是教学目的,旨在说明如何以及何时使用 @overload)。

骨架可能看起来像这样

def algorithm():
    # setup
    v = ...
    while True:
        # take a step
        d = scipy.linalg.norm(v)
        if d < tolerance:
            break

现在,我们进一步假设您已经听说过 Numba,并且希望使用它来加速您的函数。然而,在添加 jit(nopython=True) 装饰器后,Numba 抱怨 scipy.linalg.norm 不受支持。通过查阅文档,您意识到使用 NumPy 实现范数可能相当容易。一个好的起点是以下模板。

# Declare that function `myfunc` is going to be overloaded (have a
# substitutable Numba implementation)
@overload(myfunc)
# Define the overload function with formal arguments
# these arguments must be matched in the inner function implementation
def jit_myfunc(arg0, arg1, arg2, ...):
    # This scope is for typing, access is available to the *type* of all
    # arguments. This information can be used to change the behaviour of the
    # implementing function and check that the types are actually supported
    # by the implementation.

    print(arg0) # this will show the Numba type of arg0

    # This is the definition of the function that implements the `myfunc` work.
    # It does whatever algorithm is needed to implement myfunc.
    def myfunc_impl(arg0, arg1, arg2, ...): # match arguments to jit_myfunc
        # < Implementation goes here >
        return # whatever needs to be returned by the algorithm

    # return the implementation
    return myfunc_impl

经过一番推敲和修改,您得到了以下代码

import numpy as np
from numba import njit, types
from numba.extending import overload, register_jitable
from numba.core.errors import TypingError

import scipy.linalg


@register_jitable
def _oneD_norm_2(a):
    # re-usable implementation of the 2-norm
    val = np.abs(a)
    return np.sqrt(np.sum(val * val))


@overload(scipy.linalg.norm)
def jit_norm(a, ord=None):
    if isinstance(ord, types.Optional):
        ord = ord.type
    # Reject non integer, floating-point or None types for ord
    if not isinstance(ord, (types.Integer, types.Float, types.NoneType)):
        raise TypingError("'ord' must be either integer or floating-point")
    # Reject non-ndarray types
    if not isinstance(a, types.Array):
        raise TypingError("Only accepts NumPy ndarray")
    # Reject ndarrays with non integer or floating-point dtype
    if not isinstance(a.dtype, (types.Integer, types.Float)):
        raise TypingError("Only integer and floating point types accepted")
    # Reject ndarrays with unsupported dimensionality
    if not (0 <= a.ndim <= 2):
        raise TypingError('3D and beyond are not allowed')
    # Implementation for scalars/0d-arrays
    elif a.ndim == 0:
        return a.item()
    # Implementation for vectors
    elif a.ndim == 1:
        def _oneD_norm_x(a, ord=None):
            if ord == 2 or ord is None:
                return _oneD_norm_2(a)
            elif ord == np.inf:
                return np.max(np.abs(a))
            elif ord == -np.inf:
                return np.min(np.abs(a))
            elif ord == 0:
                return np.sum(a != 0)
            elif ord == 1:
                return np.sum(np.abs(a))
            else:
                return np.sum(np.abs(a)**ord)**(1. / ord)
        return _oneD_norm_x
    # Implementation for matrices
    elif a.ndim == 2:
        def _two_D_norm_2(a, ord=None):
            return _oneD_norm_2(a.ravel())
        return _two_D_norm_2


if __name__ == "__main__":
    @njit
    def use(a, ord=None):
        # simple test function to check that the overload works
        return scipy.linalg.norm(a, ord)

    # spot check for vectors
    a = np.arange(10)
    print(use(a))
    print(scipy.linalg.norm(a))

    # spot check for matrices
    b = np.arange(9).reshape((3, 3))
    print(use(b))
    print(scipy.linalg.norm(b))

如您所见,当前的实现只支持您目前所需的功能

  • 仅支持整数和浮点类型

  • 所有向量范数

  • 仅支持矩阵的 Frobenius 范数

  • 使用 @register_jitable 在向量和矩阵实现之间共享代码。

  • 范数使用 NumPy 语法实现。(这是可能的,因为 Numba 对 NumPy 非常了解,并且支持许多函数。)

那么这里实际发生了什么? overload 装饰器为 scipy.linalg.norm 注册了一个合适的实现,以防在正在进行 JIT 编译的代码中遇到对它的调用,例如当您使用 @jit(nopython=True) 装饰您的 algorithm 函数时。在这种情况下,函数 jit_norm 将使用当前遇到的类型被调用,然后返回向量情况下的 _oneD_norm_x 或矩阵情况下的 _two_D_norm_2

您可以在此处下载示例代码:mynorm.py

为 NumPy 函数实现 @overload

Numba 通过提供与 @jit 兼容的 NumPy 函数重新实现来支持 NumPy。在这种情况下,@overload 是编写此类实现的一个非常方便的选择,但是还有一些额外的事项需要注意。

  • Numba 实现应在接受的类型、参数、引发的异常以及算法复杂度(Big-O / Landau 阶)方面尽可能接近 NumPy 实现。

  • 在实现支持的参数类型时,请记住,由于鸭子类型(duck typing),NumPy 倾向于接受 NumPy 数组之外的多种参数类型,例如标量、列表、元组、集合、迭代器、生成器等。您需要在类型推断期间以及后续的测试中考虑到这一点。

  • NumPy 函数可能返回标量、数组或与其某个输入匹配的数据结构,您需要注意类型统一问题并分派到适当的实现。例如,np.corrcoef 可能会根据其输入返回数组或标量。

  • 如果您正在实现一个新函数,应始终更新文档。源文件位于 docs/source/reference/numpysupported.rst。务必提及您的实现存在的任何限制,例如不支持 axis 关键字。

  • 在为功能本身编写测试时,有益的做法是包括对非有限值、具有不同形状和布局的数组、复杂输入、标量输入以及未记录支持的类型输入(例如,NumPy 文档说明需要浮点数或整数输入的函数,如果提供布尔值或复数输入,也可能“正常工作”)的处理。

  • 在编写异常测试时,例如在向 numba/tests/test_np_functions.py 添加测试时,您可能会遇到以下错误消息

    ======================================================================
    FAIL: test_foo (numba.tests.test_np_functions.TestNPFunctions)
    ----------------------------------------------------------------------
    Traceback (most recent call last):
    File "<path>/numba/numba/tests/support.py", line 645, in tearDown
        self.memory_leak_teardown()
    File "<path>/numba/numba/tests/support.py", line 619, in memory_leak_teardown
        self.assert_no_memory_leak()
    File "<path>/numba/numba/tests/support.py", line 628, in assert_no_memory_leak
        self.assertEqual(total_alloc, total_free)
    AssertionError: 36 != 35
    

    发生这种情况是因为从 JIT 编译代码中引发异常会导致引用泄漏。理想情况下,您会将所有异常测试放在一个单独的测试方法中,然后在每个测试中添加对 self.disable_leak_check() 的调用以禁用泄漏检查(继承自 numba.tests.support.TestCase 以使其可用)。

  • 对于 NumPy 中可用的许多函数,在 NumPy ndarray 类型上定义了相应的方法。例如,函数 repeat 作为 NumPy 模块级函数和 ndarray 类上的成员函数可用。

    import numpy as np
    a = np.arange(10)
    # function
    np.repeat(a, 10)
    # method
    a.repeat(10)
    

    一旦您编写了函数实现,就可以轻松使用 @overload_method 并重用它。只需确保 NumPy 的函数/方法实现没有分歧。

    作为一个例子,repeat 函数/方法

    @extending.overload_method(types.Array, 'repeat')
    def array_repeat(a, repeats):
        def array_repeat_impl(a, repeat):
            # np.repeat has already been overloaded
            return np.repeat(a, repeat)
    
        return array_repeat_impl
    
  • 如果您需要创建辅助函数,例如为了重用一个小的实用函数,或者为了可读性而将实现拆分到多个函数中,您可以使用 @register_jitable 装饰器。这将使这些函数在您的 @jit@overload 装饰函数中可用。

  • Numba 的持续集成(CI)设置测试了各种 NumPy 版本,您有时会收到关于某个旧 NumPy 版本行为变化的警报。如果您能在 NumPy 变更日志/仓库中找到支持证据,那么您需要决定是创建分支并尝试在不同版本间复制逻辑,还是使用版本门控(并在文档中附带相关说明)来声明 Numba 从某个特定版本开始复制 NumPy。

  • 您可以查阅 Numba 源代码以获取灵感,许多重载的 NumPy 函数和方法都位于 numba/targets/arrayobj.py 中。下面,您将找到一个实现列表,这些实现在接受的类型和测试覆盖率方面都做得很好。

    • np.repeat