高级扩展API
该扩展API通过 numba.extending
模块公开。
实现函数
@overload
装饰器允许您实现任意函数,以便在 nopython 模式函数中使用。使用 @overload
装饰的函数在编译时被调用,参数是函数的运行时参数的 类型。它应该返回一个可调用对象,表示给定类型函数的 实现。返回的实现由 Numba 编译,就像它是一个用 @jit
装饰的普通函数一样。@jit
的额外选项可以通过字典形式使用 jit_options
参数传递。
例如,假设 Numba 尚不支持对元组的 len()
函数。下面是如何使用 @overload
实现它
from numba import types
from numba.extending import overload
@overload(len)
def tuple_len(seq):
if isinstance(seq, types.BaseTuple):
n = len(seq)
def len_impl(seq):
return n
return len_impl
您可能想知道,如果 len()
函数被调用时传入的不是元组,会发生什么?如果一个用 @overload
装饰的函数没有返回任何内容(即返回 None),则会尝试其他定义,直到有一个成功。因此,多个库可以为不同类型重载 len()
函数而不会相互冲突。
实现方法
@overload_method
装饰器类似地允许在 Numba 已知的类型上实现方法。
- numba.core.extending.overload_method(typ, attr, **kwargs)
一个装饰器,用于将装饰的函数标记为类型化,并在 nopython 模式下为给定的 Numba 类型实现方法 attr。
kwargs 会传递给底层的 @overload 调用。
这是一个为数组类型实现 .take() 的示例
@overload_method(types.Array, 'take') def array_take(arr, indices): if isinstance(indices, types.Array): def take_impl(arr, indices): n = indices.shape[0] res = np.empty(n, arr.dtype) for i in range(n): res[i] = arr[indices[i]] return res return take_impl
实现类方法
@overload_classmethod
装饰器类似地允许在 Numba 已知的类型上实现类方法。
- numba.core.extending.overload_classmethod(typ, attr, **kwargs)
一个装饰器,用于将装饰的函数标记为类型化,并在 nopython 模式下为给定的 Numba 类型实现类方法 attr。
类似于
overload_method
。这是一个在 Array 类型上实现类方法以调用
np.arange()
的示例@overload_classmethod(types.Array, "make") def ov_make(cls, nitems): def impl(cls, nitems): return np.arange(nitems) return impl
以上代码将允许以下内容在 JIT 编译代码中工作
@njit def foo(n): return types.Array.make(n)
实现属性
@overload_attribute
装饰器允许在类型上实现数据属性(或特性)。只能读取属性;可写属性仅通过 低级 API 支持。
以下示例实现了 Numpy 数组上的 nbytes
属性
@overload_attribute(types.Array, 'nbytes')
def array_nbytes(arr):
def get(arr):
return arr.size * arr.itemsize
return get
导入 Cython 函数
函数 get_cython_function_address
获取 Cython 扩展模块中 C 函数的地址。该地址可以通过 ctypes.CFUNCTYPE()
回调来访问 C 函数,从而允许在 Numba JIT 编译函数中使用 C 函数。例如,假设您有文件 foo.pyx
from libc.math cimport exp
cdef api double myexp(double x):
return exp(x)
您可以通过以下方式从 Numba 访问 myexp
import ctypes
from numba.extending import get_cython_function_address
addr = get_cython_function_address("foo", "myexp")
functype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
myexp = functype(addr)
函数 myexp
现在可以在 JIT 编译函数中使用,例如
@njit
def double_myexp(x):
return 2*myexp(x)
一个注意事项是,如果您的函数使用了 Cython 的融合类型(fused types),那么函数的名称将被修改(mangled)。要找出您函数的修改后的名称,您可以检查扩展模块的 __pyx_capi__
属性。
实现内联函数 (intrinsics)
@intrinsic
装饰器用于将函数 func 标记为类型化,并使用 llvmlite IRBuilder API 在 nopython
模式下实现该函数。这是为专家用户提供的逃生舱口,用于构建将内联到调用者中的自定义 LLVM IR,没有安全网!
func 的第一个参数是类型上下文。其余参数对应于被装饰函数的参数类型。这些参数也用作被装饰函数的形参。如果 func 的签名是 foo(typing_context, arg0, arg1)
,则被装饰的函数将具有签名 foo(arg0, arg1)
。
func 的返回值应该是一个包含预期类型签名和将传递给 lower_builtin()
的代码生成函数的两元组。对于不支持的操作,返回 None
。
这是一个将任意整数转换为字节指针的示例
from numba import types
from numba.extending import intrinsic
@intrinsic
def cast_int_to_byte_ptr(typingctx, src):
# check for accepted types
if isinstance(src, types.Integer):
# create the expected type signature
result_type = types.CPointer(types.uint8)
sig = result_type(types.uintp)
# defines the custom code generation
def codegen(context, builder, signature, args):
# llvm IRBuilder code here
[src] = args
rtype = signature.return_type
llrtype = context.get_value_type(rtype)
return builder.inttoptr(src, llrtype)
return sig, codegen
它可以按如下方式使用
from numba import njit
@njit('void(int64)')
def foo(x):
y = cast_int_to_byte_ptr(x)
foo.inspect_types()
并且 .inspect_types()
的输出展示了该转换(注意 uint8*
)
def foo(x):
# x = arg(0, name=x) :: int64
# $0.1 = global(cast_int_to_byte_ptr: <intrinsic cast_int_to_byte_ptr>) :: Function(<intrinsic cast_int_to_byte_ptr>)
# $0.3 = call $0.1(x, func=$0.1, args=[Var(x, check_intrin.py (24))], kws=(), vararg=None) :: (uint64,) -> uint8*
# del x
# del $0.1
# y = $0.3 :: uint8*
# del y
# del $0.3
# $const0.4 = const(NoneType, None) :: none
# $0.5 = cast(value=$const0.4) :: none
# del $const0.4
# return $0.5
y = cast_int_to_byte_ptr(x)
实现可变结构
警告
这是一个实验性功能,API 可能会在不发出警告的情况下更改。
numba.experimental.structref
模块提供了定义可变传引用结构(即 StructRef
)的工具。以下示例演示了如何定义一个基本的、可变的结构
定义 StructRef
numba/tests/doc_examples/test_structref_usage.py
1import numpy as np
2
3from numba import njit
4from numba.core import types
5from numba.experimental import structref
6
7from numba.tests.support import skip_unless_scipy
8
9
10# Define a StructRef.
11# `structref.register` associates the type with the default data model.
12# This will also install getters and setters to the fields of
13# the StructRef.
14@structref.register
15class MyStructType(types.StructRef):
16 def preprocess_fields(self, fields):
17 # This method is called by the type constructor for additional
18 # preprocessing on the fields.
19 # Here, we don't want the struct to take Literal types.
20 return tuple((name, types.unliteral(typ)) for name, typ in fields)
21
22
23# Define a Python type that can be use as a proxy to the StructRef
24# allocated inside Numba. Users can construct the StructRef via
25# the constructor for this type in python code and jit-code.
26class MyStruct(structref.StructRefProxy):
27 def __new__(cls, name, vector):
28 # Overriding the __new__ method is optional, doing so
29 # allows Python code to use keyword arguments,
30 # or add other customized behavior.
31 # The default __new__ takes `*args`.
32 # IMPORTANT: Users should not override __init__.
33 return structref.StructRefProxy.__new__(cls, name, vector)
34
35 # By default, the proxy type does not reflect the attributes or
36 # methods to the Python side. It is up to users to define
37 # these. (This may be automated in the future.)
38
39 @property
40 def name(self):
41 # To access a field, we can define a function that simply
42 # return the field in jit-code.
43 # The definition of MyStruct_get_name is shown later.
44 return MyStruct_get_name(self)
45
46 @property
47 def vector(self):
48 # The definition of MyStruct_get_vector is shown later.
49 return MyStruct_get_vector(self)
50
51
52@njit
53def MyStruct_get_name(self):
54 # In jit-code, the StructRef's attribute is exposed via
55 # structref.register
56 return self.name
57
58
59@njit
60def MyStruct_get_vector(self):
61 return self.vector
62
63
64# This associates the proxy with MyStructType for the given set of
65# fields. Notice how we are not constraining the type of each field.
66# Field types remain generic.
67structref.define_proxy(MyStruct, MyStructType, ["name", "vector"])
以下演示了如何使用上述可变结构定义
numba/tests/doc_examples/test_structref_usage.py
的 test_type_definition
1# Let's test our new StructRef.
2
3# Define one in Python
4alice = MyStruct("Alice", vector=np.random.random(3))
5
6# Define one in jit-code
7@njit
8def make_bob():
9 bob = MyStruct("unnamed", vector=np.zeros(3))
10 # Mutate the attributes
11 bob.name = "Bob"
12 bob.vector = np.random.random(3)
13 return bob
14
15bob = make_bob()
16
17# Out: Alice: [0.5488135 0.71518937 0.60276338]
18print(f"{alice.name}: {alice.vector}")
19# Out: Bob: [0.88325739 0.73527629 0.87746707]
20print(f"{bob.name}: {bob.vector}")
21
22# Define a jit function to operate on the structs.
23@njit
24def distance(a, b):
25 return np.linalg.norm(a.vector - b.vector)
26
27# Out: 0.4332647200356598
28print(distance(alice, bob))
在 StructRef 上定义方法
如前几节所示,方法和属性可以使用 @overload_*
进行附加。
以下演示了如何使用 @overload_method
为 MyStructType
的实例插入方法
numba/tests/doc_examples/test_structref_usage.py
的 test_overload_method
1from numba.core.extending import overload_method
2from numba.core.errors import TypingError
3
4# Use @overload_method to add a method for
5# MyStructType.distance(other)
6# where *other* is an instance of MyStructType.
7@overload_method(MyStructType, "distance")
8def ol_distance(self, other):
9 # Guard that *other* is an instance of MyStructType
10 if not isinstance(other, MyStructType):
11 raise TypingError(
12 f"*other* must be a {MyStructType}; got {other}"
13 )
14
15 def impl(self, other):
16 return np.linalg.norm(self.vector - other.vector)
17
18 return impl
19
20# Test
21@njit
22def test():
23 alice = MyStruct("Alice", vector=np.random.random(3))
24 bob = MyStruct("Bob", vector=np.random.random(3))
25 # Use the method
26 return alice.distance(bob)
numba.experimental.structref
API 参考
定义可变结构的实用工具。
可变结构是按引用传递的;因此,structref(对结构的引用)。
- class numba.experimental.structref.StructRefProxy(*args)
一个 PyObject 代理,指向 Numba 分配的 structref 数据结构。
注意
子类不应定义
__init__
。子类可以重写
__new__
。
- numba.experimental.structref.define_attributes(struct_typeclass)
在 struct_typeclass 上定义属性。
在 JIT 代码中定义了 setter 和 getter。
这在 register() 中直接调用。
- numba.experimental.structref.define_boxing(struct_type, obj_class)
定义 struct_type 到 obj_class 的装箱(boxing)和拆箱(unboxing)逻辑。
定义了装箱和拆箱。
装箱将 struct_type 的实例转换为 obj_class 的 PyObject
拆箱将 obj_class 的实例转换为 JIT 代码中 struct_type 的实例。
当用户不希望定义任何构造函数时,直接使用此方法而不是 define_proxy()。
- numba.experimental.structref.define_constructor(py_class, struct_typeclass, fields)
使用 Python 类型 py_class 和所需的 fields 为 struct_typeclass 定义 JIT 代码构造函数。
如果用户不希望定义装箱逻辑,请使用此方法而不是 define_proxy()。
- numba.experimental.structref.define_proxy(py_class, struct_typeclass, fields)
为 structref 定义一个 PyObject 代理。
这使得 py_class 成为一个有效的构造函数,用于创建包含由 fields 定义的成员的 struct_typeclass 实例。
- 参数
- py_class类型
用于构造 struct_typeclass 实例的 Python 类。
- struct_typeclassnumba.core.types.Type
要绑定的 structref 类型类。
- fieldsSequence[str]
字段名称序列。
- 返回
- None
- numba.experimental.structref.register(struct_type)
注册一个 numba.core.types.StructRef 以在 JIT 代码中使用。
这定义了降低 struct_type 实例的数据模型。这定义了 struct_type 实例的属性访问器和修改器。
- 参数
- struct_type类型
numba.core.types.StructRef 的子类。
- 返回
- struct_type类型
返回输入参数,因此这可以作为一个装饰器。
示例
class MyStruct(numba.core.types.StructRef): ... # the simplest subclass can be empty numba.experimental.structref.register(MyStruct)