使用 @overload
的指南
如 高级扩展 API 中所述,您可以使用 @overload
装饰器来创建可在 nopython 模式函数中使用的函数的 Numba 实现。一个常见的用例是重新实现 NumPy 函数,以便它们可以在由 @jit
装饰的代码中调用。本节讨论了何时以及如何使用 @overload
装饰器,以及向 Numba 代码库贡献此类函数可能涉及的内容。这应能帮助您在需要使用 @overload
装饰器或尝试向 Numba 本身贡献新函数时入门。
@overload
装饰器及其变体在您拥有无法控制的第三方库,并希望为该库中的特定函数提供 Numba 兼容的实现时非常有用。
具体示例
假设您正在开发一种最小化算法,该算法利用 scipy.linalg.norm
来查找不同的向量范数和矩阵的 Frobenius 范数。您知道只会涉及整数和实数。(尽管这听起来像一个人工示例,尤其是因为 numpy.linalg.norm
的 Numba 实现已经存在,但它主要是教学目的,旨在说明如何以及何时使用 @overload
)。
骨架可能看起来像这样
def algorithm():
# setup
v = ...
while True:
# take a step
d = scipy.linalg.norm(v)
if d < tolerance:
break
现在,我们进一步假设您已经听说过 Numba,并且希望使用它来加速您的函数。然而,在添加 jit(nopython=True)
装饰器后,Numba 抱怨 scipy.linalg.norm
不受支持。通过查阅文档,您意识到使用 NumPy 实现范数可能相当容易。一个好的起点是以下模板。
# Declare that function `myfunc` is going to be overloaded (have a
# substitutable Numba implementation)
@overload(myfunc)
# Define the overload function with formal arguments
# these arguments must be matched in the inner function implementation
def jit_myfunc(arg0, arg1, arg2, ...):
# This scope is for typing, access is available to the *type* of all
# arguments. This information can be used to change the behaviour of the
# implementing function and check that the types are actually supported
# by the implementation.
print(arg0) # this will show the Numba type of arg0
# This is the definition of the function that implements the `myfunc` work.
# It does whatever algorithm is needed to implement myfunc.
def myfunc_impl(arg0, arg1, arg2, ...): # match arguments to jit_myfunc
# < Implementation goes here >
return # whatever needs to be returned by the algorithm
# return the implementation
return myfunc_impl
经过一番推敲和修改,您得到了以下代码
import numpy as np
from numba import njit, types
from numba.extending import overload, register_jitable
from numba.core.errors import TypingError
import scipy.linalg
@register_jitable
def _oneD_norm_2(a):
# re-usable implementation of the 2-norm
val = np.abs(a)
return np.sqrt(np.sum(val * val))
@overload(scipy.linalg.norm)
def jit_norm(a, ord=None):
if isinstance(ord, types.Optional):
ord = ord.type
# Reject non integer, floating-point or None types for ord
if not isinstance(ord, (types.Integer, types.Float, types.NoneType)):
raise TypingError("'ord' must be either integer or floating-point")
# Reject non-ndarray types
if not isinstance(a, types.Array):
raise TypingError("Only accepts NumPy ndarray")
# Reject ndarrays with non integer or floating-point dtype
if not isinstance(a.dtype, (types.Integer, types.Float)):
raise TypingError("Only integer and floating point types accepted")
# Reject ndarrays with unsupported dimensionality
if not (0 <= a.ndim <= 2):
raise TypingError('3D and beyond are not allowed')
# Implementation for scalars/0d-arrays
elif a.ndim == 0:
return a.item()
# Implementation for vectors
elif a.ndim == 1:
def _oneD_norm_x(a, ord=None):
if ord == 2 or ord is None:
return _oneD_norm_2(a)
elif ord == np.inf:
return np.max(np.abs(a))
elif ord == -np.inf:
return np.min(np.abs(a))
elif ord == 0:
return np.sum(a != 0)
elif ord == 1:
return np.sum(np.abs(a))
else:
return np.sum(np.abs(a)**ord)**(1. / ord)
return _oneD_norm_x
# Implementation for matrices
elif a.ndim == 2:
def _two_D_norm_2(a, ord=None):
return _oneD_norm_2(a.ravel())
return _two_D_norm_2
if __name__ == "__main__":
@njit
def use(a, ord=None):
# simple test function to check that the overload works
return scipy.linalg.norm(a, ord)
# spot check for vectors
a = np.arange(10)
print(use(a))
print(scipy.linalg.norm(a))
# spot check for matrices
b = np.arange(9).reshape((3, 3))
print(use(b))
print(scipy.linalg.norm(b))
如您所见,当前的实现只支持您目前所需的功能
仅支持整数和浮点类型
所有向量范数
仅支持矩阵的 Frobenius 范数
使用
@register_jitable
在向量和矩阵实现之间共享代码。范数使用 NumPy 语法实现。(这是可能的,因为 Numba 对 NumPy 非常了解,并且支持许多函数。)
那么这里实际发生了什么? overload
装饰器为 scipy.linalg.norm
注册了一个合适的实现,以防在正在进行 JIT 编译的代码中遇到对它的调用,例如当您使用 @jit(nopython=True)
装饰您的 algorithm
函数时。在这种情况下,函数 jit_norm
将使用当前遇到的类型被调用,然后返回向量情况下的 _oneD_norm_x
或矩阵情况下的 _two_D_norm_2
。
您可以在此处下载示例代码:mynorm.py
为 NumPy 函数实现 @overload
Numba 通过提供与 @jit
兼容的 NumPy 函数重新实现来支持 NumPy。在这种情况下,@overload
是编写此类实现的一个非常方便的选择,但是还有一些额外的事项需要注意。
Numba 实现应在接受的类型、参数、引发的异常以及算法复杂度(Big-O / Landau 阶)方面尽可能接近 NumPy 实现。
在实现支持的参数类型时,请记住,由于鸭子类型(duck typing),NumPy 倾向于接受 NumPy 数组之外的多种参数类型,例如标量、列表、元组、集合、迭代器、生成器等。您需要在类型推断期间以及后续的测试中考虑到这一点。
NumPy 函数可能返回标量、数组或与其某个输入匹配的数据结构,您需要注意类型统一问题并分派到适当的实现。例如,
np.corrcoef
可能会根据其输入返回数组或标量。
如果您正在实现一个新函数,应始终更新文档。源文件位于
docs/source/reference/numpysupported.rst
。务必提及您的实现存在的任何限制,例如不支持axis
关键字。在为功能本身编写测试时,有益的做法是包括对非有限值、具有不同形状和布局的数组、复杂输入、标量输入以及未记录支持的类型输入(例如,NumPy 文档说明需要浮点数或整数输入的函数,如果提供布尔值或复数输入,也可能“正常工作”)的处理。
在编写异常测试时,例如在向
numba/tests/test_np_functions.py
添加测试时,您可能会遇到以下错误消息====================================================================== FAIL: test_foo (numba.tests.test_np_functions.TestNPFunctions) ---------------------------------------------------------------------- Traceback (most recent call last): File "<path>/numba/numba/tests/support.py", line 645, in tearDown self.memory_leak_teardown() File "<path>/numba/numba/tests/support.py", line 619, in memory_leak_teardown self.assert_no_memory_leak() File "<path>/numba/numba/tests/support.py", line 628, in assert_no_memory_leak self.assertEqual(total_alloc, total_free) AssertionError: 36 != 35
发生这种情况是因为从 JIT 编译代码中引发异常会导致引用泄漏。理想情况下,您会将所有异常测试放在一个单独的测试方法中,然后在每个测试中添加对
self.disable_leak_check()
的调用以禁用泄漏检查(继承自numba.tests.support.TestCase
以使其可用)。对于 NumPy 中可用的许多函数,在 NumPy
ndarray
类型上定义了相应的方法。例如,函数repeat
作为 NumPy 模块级函数和ndarray
类上的成员函数可用。import numpy as np a = np.arange(10) # function np.repeat(a, 10) # method a.repeat(10)
一旦您编写了函数实现,就可以轻松使用
@overload_method
并重用它。只需确保 NumPy 的函数/方法实现没有分歧。作为一个例子,
repeat
函数/方法@extending.overload_method(types.Array, 'repeat') def array_repeat(a, repeats): def array_repeat_impl(a, repeat): # np.repeat has already been overloaded return np.repeat(a, repeat) return array_repeat_impl
如果您需要创建辅助函数,例如为了重用一个小的实用函数,或者为了可读性而将实现拆分到多个函数中,您可以使用
@register_jitable
装饰器。这将使这些函数在您的@jit
和@overload
装饰函数中可用。Numba 的持续集成(CI)设置测试了各种 NumPy 版本,您有时会收到关于某个旧 NumPy 版本行为变化的警报。如果您能在 NumPy 变更日志/仓库中找到支持证据,那么您需要决定是创建分支并尝试在不同版本间复制逻辑,还是使用版本门控(并在文档中附带相关说明)来声明 Numba 从某个特定版本开始复制 NumPy。
您可以查阅 Numba 源代码以获取灵感,许多重载的 NumPy 函数和方法都位于
numba/targets/arrayobj.py
中。下面,您将找到一个实现列表,这些实现在接受的类型和测试覆盖率方面都做得很好。np.repeat