高级扩展API

该扩展API通过 numba.extending 模块公开。

实现函数

@overload 装饰器允许您实现任意函数,以便在 nopython 模式函数中使用。使用 @overload 装饰的函数在编译时被调用,参数是函数的运行时参数的 类型。它应该返回一个可调用对象,表示给定类型函数的 实现。返回的实现由 Numba 编译,就像它是一个用 @jit 装饰的普通函数一样。@jit 的额外选项可以通过字典形式使用 jit_options 参数传递。

例如,假设 Numba 尚不支持对元组的 len() 函数。下面是如何使用 @overload 实现它

from numba import types
from numba.extending import overload

@overload(len)
def tuple_len(seq):
   if isinstance(seq, types.BaseTuple):
       n = len(seq)
       def len_impl(seq):
           return n
       return len_impl

您可能想知道,如果 len() 函数被调用时传入的不是元组,会发生什么?如果一个用 @overload 装饰的函数没有返回任何内容(即返回 None),则会尝试其他定义,直到有一个成功。因此,多个库可以为不同类型重载 len() 函数而不会相互冲突。

实现方法

@overload_method 装饰器类似地允许在 Numba 已知的类型上实现方法。

numba.core.extending.overload_method(typ, attr, **kwargs)

一个装饰器,用于将装饰的函数标记为类型化,并在 nopython 模式下为给定的 Numba 类型实现方法 attr

kwargs 会传递给底层的 @overload 调用。

这是一个为数组类型实现 .take() 的示例

@overload_method(types.Array, 'take')
def array_take(arr, indices):
    if isinstance(indices, types.Array):
        def take_impl(arr, indices):
            n = indices.shape[0]
            res = np.empty(n, arr.dtype)
            for i in range(n):
                res[i] = arr[indices[i]]
            return res
        return take_impl

实现类方法

@overload_classmethod 装饰器类似地允许在 Numba 已知的类型上实现类方法。

numba.core.extending.overload_classmethod(typ, attr, **kwargs)

一个装饰器,用于将装饰的函数标记为类型化,并在 nopython 模式下为给定的 Numba 类型实现类方法 attr

类似于 overload_method

这是一个在 Array 类型上实现类方法以调用 np.arange() 的示例

@overload_classmethod(types.Array, "make")
def ov_make(cls, nitems):
    def impl(cls, nitems):
        return np.arange(nitems)
    return impl

以上代码将允许以下内容在 JIT 编译代码中工作

@njit
def foo(n):
    return types.Array.make(n)

实现属性

@overload_attribute 装饰器允许在类型上实现数据属性(或特性)。只能读取属性;可写属性仅通过 低级 API 支持。

以下示例实现了 Numpy 数组上的 nbytes 属性

@overload_attribute(types.Array, 'nbytes')
def array_nbytes(arr):
   def get(arr):
       return arr.size * arr.itemsize
   return get

导入 Cython 函数

函数 get_cython_function_address 获取 Cython 扩展模块中 C 函数的地址。该地址可以通过 ctypes.CFUNCTYPE() 回调来访问 C 函数,从而允许在 Numba JIT 编译函数中使用 C 函数。例如,假设您有文件 foo.pyx

from libc.math cimport exp

cdef api double myexp(double x):
    return exp(x)

您可以通过以下方式从 Numba 访问 myexp

import ctypes
from numba.extending import get_cython_function_address

addr = get_cython_function_address("foo", "myexp")
functype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
myexp = functype(addr)

函数 myexp 现在可以在 JIT 编译函数中使用,例如

@njit
def double_myexp(x):
    return 2*myexp(x)

一个注意事项是,如果您的函数使用了 Cython 的融合类型(fused types),那么函数的名称将被修改(mangled)。要找出您函数的修改后的名称,您可以检查扩展模块的 __pyx_capi__ 属性。

实现内联函数 (intrinsics)

@intrinsic 装饰器用于将函数 func 标记为类型化,并使用 llvmlite IRBuilder APInopython 模式下实现该函数。这是为专家用户提供的逃生舱口,用于构建将内联到调用者中的自定义 LLVM IR,没有安全网!

func 的第一个参数是类型上下文。其余参数对应于被装饰函数的参数类型。这些参数也用作被装饰函数的形参。如果 func 的签名是 foo(typing_context, arg0, arg1),则被装饰的函数将具有签名 foo(arg0, arg1)

func 的返回值应该是一个包含预期类型签名和将传递给 lower_builtin() 的代码生成函数的两元组。对于不支持的操作,返回 None

这是一个将任意整数转换为字节指针的示例

from numba import types
from numba.extending import intrinsic

@intrinsic
def cast_int_to_byte_ptr(typingctx, src):
    # check for accepted types
    if isinstance(src, types.Integer):
        # create the expected type signature
        result_type = types.CPointer(types.uint8)
        sig = result_type(types.uintp)
        # defines the custom code generation
        def codegen(context, builder, signature, args):
            # llvm IRBuilder code here
            [src] = args
            rtype = signature.return_type
            llrtype = context.get_value_type(rtype)
            return builder.inttoptr(src, llrtype)
        return sig, codegen

它可以按如下方式使用

from numba import njit

@njit('void(int64)')
def foo(x):
    y = cast_int_to_byte_ptr(x)

foo.inspect_types()

并且 .inspect_types() 的输出展示了该转换(注意 uint8*

def foo(x):

    #   x = arg(0, name=x)  :: int64
    #   $0.1 = global(cast_int_to_byte_ptr: <intrinsic cast_int_to_byte_ptr>)  :: Function(<intrinsic cast_int_to_byte_ptr>)
    #   $0.3 = call $0.1(x, func=$0.1, args=[Var(x, check_intrin.py (24))], kws=(), vararg=None)  :: (uint64,) -> uint8*
    #   del x
    #   del $0.1
    #   y = $0.3  :: uint8*
    #   del y
    #   del $0.3
    #   $const0.4 = const(NoneType, None)  :: none
    #   $0.5 = cast(value=$const0.4)  :: none
    #   del $const0.4
    #   return $0.5

    y = cast_int_to_byte_ptr(x)

实现可变结构

警告

这是一个实验性功能,API 可能会在不发出警告的情况下更改。

numba.experimental.structref 模块提供了定义可变传引用结构(即 StructRef)的工具。以下示例演示了如何定义一个基本的、可变的结构

定义 StructRef

来自 numba/tests/doc_examples/test_structref_usage.py
 1import numpy as np
 2
 3from numba import njit
 4from numba.core import types
 5from numba.experimental import structref
 6
 7from numba.tests.support import skip_unless_scipy
 8
 9
10# Define a StructRef.
11# `structref.register` associates the type with the default data model.
12# This will also install getters and setters to the fields of
13# the StructRef.
14@structref.register
15class MyStructType(types.StructRef):
16    def preprocess_fields(self, fields):
17        # This method is called by the type constructor for additional
18        # preprocessing on the fields.
19        # Here, we don't want the struct to take Literal types.
20        return tuple((name, types.unliteral(typ)) for name, typ in fields)
21
22
23# Define a Python type that can be use as a proxy to the StructRef
24# allocated inside Numba. Users can construct the StructRef via
25# the constructor for this type in python code and jit-code.
26class MyStruct(structref.StructRefProxy):
27    def __new__(cls, name, vector):
28        # Overriding the __new__ method is optional, doing so
29        # allows Python code to use keyword arguments,
30        # or add other customized behavior.
31        # The default __new__ takes `*args`.
32        # IMPORTANT: Users should not override __init__.
33        return structref.StructRefProxy.__new__(cls, name, vector)
34
35    # By default, the proxy type does not reflect the attributes or
36    # methods to the Python side. It is up to users to define
37    # these. (This may be automated in the future.)
38
39    @property
40    def name(self):
41        # To access a field, we can define a function that simply
42        # return the field in jit-code.
43        # The definition of MyStruct_get_name is shown later.
44        return MyStruct_get_name(self)
45
46    @property
47    def vector(self):
48        # The definition of MyStruct_get_vector is shown later.
49        return MyStruct_get_vector(self)
50
51
52@njit
53def MyStruct_get_name(self):
54    # In jit-code, the StructRef's attribute is exposed via
55    # structref.register
56    return self.name
57
58
59@njit
60def MyStruct_get_vector(self):
61    return self.vector
62
63
64# This associates the proxy with MyStructType for the given set of
65# fields. Notice how we are not constraining the type of each field.
66# Field types remain generic.
67structref.define_proxy(MyStruct, MyStructType, ["name", "vector"])

以下演示了如何使用上述可变结构定义

来自 numba/tests/doc_examples/test_structref_usage.pytest_type_definition
 1# Let's test our new StructRef.
 2
 3# Define one in Python
 4alice = MyStruct("Alice", vector=np.random.random(3))
 5
 6# Define one in jit-code
 7@njit
 8def make_bob():
 9    bob = MyStruct("unnamed", vector=np.zeros(3))
10    # Mutate the attributes
11    bob.name = "Bob"
12    bob.vector = np.random.random(3)
13    return bob
14
15bob = make_bob()
16
17# Out: Alice: [0.5488135  0.71518937 0.60276338]
18print(f"{alice.name}: {alice.vector}")
19# Out: Bob: [0.88325739 0.73527629 0.87746707]
20print(f"{bob.name}: {bob.vector}")
21
22# Define a jit function to operate on the structs.
23@njit
24def distance(a, b):
25    return np.linalg.norm(a.vector - b.vector)
26
27# Out: 0.4332647200356598
28print(distance(alice, bob))

在 StructRef 上定义方法

如前几节所示,方法和属性可以使用 @overload_* 进行附加。

以下演示了如何使用 @overload_methodMyStructType 的实例插入方法

来自 numba/tests/doc_examples/test_structref_usage.pytest_overload_method
 1from numba.core.extending import overload_method
 2from numba.core.errors import TypingError
 3
 4# Use @overload_method to add a method for
 5# MyStructType.distance(other)
 6# where *other* is an instance of MyStructType.
 7@overload_method(MyStructType, "distance")
 8def ol_distance(self, other):
 9    # Guard that *other* is an instance of MyStructType
10    if not isinstance(other, MyStructType):
11        raise TypingError(
12            f"*other* must be a {MyStructType}; got {other}"
13        )
14
15    def impl(self, other):
16        return np.linalg.norm(self.vector - other.vector)
17
18    return impl
19
20# Test
21@njit
22def test():
23    alice = MyStruct("Alice", vector=np.random.random(3))
24    bob = MyStruct("Bob", vector=np.random.random(3))
25    # Use the method
26    return alice.distance(bob)

numba.experimental.structref API 参考

定义可变结构的实用工具。

可变结构是按引用传递的;因此,structref(对结构的引用)。

class numba.experimental.structref.StructRefProxy(*args)

一个 PyObject 代理,指向 Numba 分配的 structref 数据结构。

注意

  • 子类不应定义 __init__

  • 子类可以重写 __new__

numba.experimental.structref.define_attributes(struct_typeclass)

struct_typeclass 上定义属性。

在 JIT 代码中定义了 setter 和 getter。

这在 register() 中直接调用。

numba.experimental.structref.define_boxing(struct_type, obj_class)

定义 struct_typeobj_class 的装箱(boxing)和拆箱(unboxing)逻辑。

定义了装箱和拆箱。

  • 装箱将 struct_type 的实例转换为 obj_class 的 PyObject

  • 拆箱将 obj_class 的实例转换为 JIT 代码中 struct_type 的实例。

当用户不希望定义任何构造函数时,直接使用此方法而不是 define_proxy()

numba.experimental.structref.define_constructor(py_class, struct_typeclass, fields)

使用 Python 类型 py_class 和所需的 fieldsstruct_typeclass 定义 JIT 代码构造函数。

如果用户不希望定义装箱逻辑,请使用此方法而不是 define_proxy()

numba.experimental.structref.define_proxy(py_class, struct_typeclass, fields)

为 structref 定义一个 PyObject 代理。

这使得 py_class 成为一个有效的构造函数,用于创建包含由 fields 定义的成员的 struct_typeclass 实例。

参数
py_class类型

用于构造 struct_typeclass 实例的 Python 类。

struct_typeclassnumba.core.types.Type

要绑定的 structref 类型类。

fieldsSequence[str]

字段名称序列。

返回
None
numba.experimental.structref.register(struct_type)

注册一个 numba.core.types.StructRef 以在 JIT 代码中使用。

这定义了降低 struct_type 实例的数据模型。这定义了 struct_type 实例的属性访问器和修改器。

参数
struct_type类型

numba.core.types.StructRef 的子类。

返回
struct_type类型

返回输入参数,因此这可以作为一个装饰器。

示例

class MyStruct(numba.core.types.StructRef):
    ...  # the simplest subclass can be empty

numba.experimental.structref.register(MyStruct)

判断一个函数是否已被 jit 系列装饰器包装

为此提供了以下函数。

extending.is_jitted()

如果一个函数被 Numba 的 @jit 装饰器之一包装(例如:numba.jit, numba.njit),则返回 True

此函数的目的是提供一种方法来检查函数是否已被 JIT 装饰。