类型和签名

基本原理

作为一种优化编译器,Numba 需要确定每个变量的类型以生成高效的机器代码。Python 的标准类型不够精确,因此我们不得不开发自己的细粒度类型系统。

你主要会在检查 Numba 类型推断结果时遇到 Numba 类型,这通常是为了调试教育目的。但是,如果预先编译代码,则需要明确使用类型。

签名

签名指定了函数的类型。允许哪种签名类型取决于上下文(AOTJIT 编译),但签名总是包含 Numba 类型的一些表示形式,以指定函数参数的具体类型以及(如果需要)函数的返回类型。

一个函数签名的例子是字符串 "f8(i4, i4)"(或等效的 "float64(int32, int32)"),它指定了一个接受两个 32 位整数并返回一个双精度浮点数的函数。

基本类型

最基本的类型可以通过简单的表达式来表示。下面的符号指的是主 numba 模块的属性(所以如果你看到“boolean”,这意味着该符号可以作为 numba.boolean 访问)。许多类型都以规范名称和简写别名两种形式提供,遵循 NumPy 的约定。

数字

下表包含 Numba 当前定义的基本数值类型及其别名。

类型名称

缩写

注释

boolean

b1

表示为字节

uint8, byte

u1

8 位无符号字节

uint16

u2

16 位无符号整数

uint32

u4

32 位无符号整数

uint64

u8

64 位无符号整数

int8, char

i1

8 位有符号字节

int16

i2

16 位有符号整数

int32

i4

32 位有符号整数

int64

i8

64 位有符号整数

intc

C int 大小整数

uintc

C int 大小无符号整数

intp

指针大小整数

uintp

指针大小无符号整数

ssize_t

C ssize_t

size_t

C size_t

float32

f4

单精度浮点数

float64, double

f8

双精度浮点数

complex64

c8

单精度复数

complex128

c16

双精度复数

数组

声明 Array 类型的一种简单方法是根据维度数量对基本类型进行下标。例如一个一维单精度数组

>>> numba.float32[:]
array(float32, 1d, A)

或一个具有相同基础类型的三维数组

>>> numba.float32[:, :, :]
array(float32, 3d, A)

这种语法定义了没有特定布局的数组类型(生成的代码既接受非连续数组也接受连续数组),但你可以通过在索引规范的开头或结尾使用 ::1 索引来指定特定的连续性

>>> numba.float32[::1]
array(float32, 1d, C)
>>> numba.float32[:, :, ::1]
array(float32, 3d, C)
>>> numba.float32[::1, :, :]
array(float32, 3d, F)

这种类型声明风格在 Numba 编译函数中受支持,例如声明 typed.List 的类型。

from numba import njit, types, typed

@njit
def example():
    return typed.List.empty_list(types.float64[:, ::1])

请注意,此功能仅支持简单的数值类型。不支持应用于复合类型,例如记录类型。

函数

警告

将函数视为头等类型对象的功能正在开发中。

函数通常被认为是输入参数到输出值的某种转换。在 Numba JIT 编译函数中,函数也可以被视为对象,也就是说,除了可调用之外,函数还可以作为参数传递或作为返回值,或者在序列中用作项。

除了以下情况,所有 Numba JIT 编译函数和 Numba cfunc 编译函数都支持头等函数:

  • 使用非 CPU 编译器时,

  • 编译函数是 Python 生成器时,

  • 编译函数具有省略参数时,

  • 或编译函数返回可选值时。

要禁用头等函数支持,请使用 no_cfunc_wrapper=True 装饰器选项。

例如,考虑一个 Numba JIT 编译函数将用户指定的函数作为组合应用于输入参数的示例

>>> @numba.njit
... def composition(funcs, x):
...     r = x
...     for f in funcs[::-1]:
...         r = f(r)
...     return r
...
>>> @numba.cfunc("double(double)")
... def a(x):
...     return x + 1.0
...
>>> @numba.njit
... def b(x):
...     return x * x
...
>>> composition((a, b), 0.5), 0.5 ** 2 + 1
(1.25, 1.25)
>>> composition((b, a, b, b, a), 0.5), b(a(b(b(a(0.5)))))
(36.75390625, 36.75390625)

在这里,cfunc 编译函数 ab 被视为头等函数对象,因为它们作为参数传递给 Numba JIT 编译函数 composition,也就是说,composition 是独立于其参数函数对象(这些对象收集在输入参数 funcs 中)进行 JIT 编译的。

目前,头等函数对象可以是 Numba cfunc 编译函数、JIT 编译函数以及实现包装器地址协议(WAP,参见下文)的对象,但有以下限制:

上下文

JIT 编译

cfunc 编译

WAP 对象

可用作参数

可调用

可用作项

是*

可返回

命名空间作用域

自动重载

* 头等函数对象序列中的至少一个项必须具有精确的类型。

包装器地址协议 - WAP

包装器地址协议提供了一个 API,用于将任何 Python 对象转换为 Numba JIT 编译函数的头等函数。这假设 Python 对象表示一个编译函数,可以从 Numba JIT 编译函数通过其内存地址(函数指针值)调用。所谓的 WAP 对象必须定义以下两种方法:

__wrapper_address__(self) int

返回头等函数的内存地址。当 Numba JIT 编译函数尝试调用给定的 WAP 实例时,会使用此方法。

signature(self) numba.typing.Signature

返回给定头等函数的签名。当将给定 WAP 实例传递给 Numba JIT 编译函数时,会使用此方法。

此外,WAP 对象可以实现 __call__ 方法。这在从 Numba JIT 编译函数以对象模式调用 WAP 对象时是必需的。

举例来说,让我们在 Numba JIT 编译函数中调用标准数学库函数 cos。加载数学库并使用 ctypes 包后,可以建立 cos 的内存地址

>>> import numba, ctypes, ctypes.util, math
>>> libm = ctypes.cdll.LoadLibrary(ctypes.util.find_library('m'))
>>> class LibMCos(numba.types.WrapperAddressProtocol):
...     def __wrapper_address__(self):
...         return ctypes.cast(libm.cos, ctypes.c_voidp).value
...     def signature(self):
...         return numba.float64(numba.float64)
...
>>> @numba.njit
... def foo(f, x):
...     return f(x)
...
>>> foo(LibMCos(), 0.0)
1.0
>>> foo(LibMCos(), 0.5), math.cos(0.5)
(0.8775825618903728, 0.8775825618903728)

杂项类型

有一些非数值类型不属于其他类别。

类型名称

注释

pyobject

通用 Python 对象

voidptr

原始指针,不能对其执行任何操作

高级类型

对于更高级的声明,你必须显式调用 Numba 提供的辅助函数或类。

警告

此处文档化的 API 不保证稳定。除非必要,建议使用 无签名版本的 @jit 让 Numba 推断参数类型。

推断

numba.typeof(value)

创建准确描述给定 Python value 的 Numba 类型。如果该值在 nopython 模式中不受支持,则会引发 ValueError

>>> numba.typeof(np.empty(3))
array(float64, 1d, C)
>>> numba.typeof((1, 2.0))
(int64, float64)
>>> numba.typeof([0])
reflected list(int64)

NumPy 标量

除了使用 typeof(),非平凡标量(如结构化类型)也可以通过编程方式构建。

numba.from_dtype(dtype)

创建与给定 NumPy dtype 对应的 Numba 类型

>>> struct_dtype = np.dtype([('row', np.float64), ('col', np.float64)])
>>> ty = numba.from_dtype(struct_dtype)
>>> ty
Record([('row', '<f8'), ('col', '<f8')])
>>> ty[:, :]
unaligned array(Record([('row', '<f8'), ('col', '<f8')]), 2d, A)
class numba.types.NPDatetime(unit)

为给定 unit 的 NumPy datetime 创建 Numba 类型。unit 应该是 NumPy 识别的代码(例如 Y, M, D 等)中的字符串。

class numba.types.NPTimedelta(unit)

为给定 unit 的 NumPy timedelta 创建 Numba 类型。unit 应该是 NumPy 识别的代码(例如 Y, M, D 等)中的字符串。

另请参阅

NumPy datetime 单位

数组

class numba.types.Array(dtype, ndim, layout)

创建一个数组类型。dtype 应该是一个 Numba 类型。ndim 是数组的维度数(一个正整数)。layout 是一个字符串,表示数组的布局:A 表示任意布局,C 表示 C 连续,F 表示 Fortran 连续。

可选类型

class numba.optional(typ)

基于底层 Numba 类型 typ 创建一个可选类型。该可选类型将允许 typNone 的任何值。

>>> @jit((optional(intp),))
... def f(x):
...     return x is not None
...
>>> f(0)
True
>>> f(None)
False

类型注解

numba.extending.as_numba_type(py_type)

创建与给定 Python 类型注解 对应的 Numba 类型。如果类型注解无法映射到 Numba 类型,则会引发 TypingError。此函数旨在用于静态编译时评估 Python 类型注解。有关 Python 对象的运行时检查,请参见上文的 typeof

对于任何 numba 类型,as_numba_type(nb_type) == nb_type

>>> numba.extending.as_numba_type(int)
int64
>>> import typing  # the Python library, not the Numba one
>>> numba.extending.as_numba_type(typing.List[float])
ListType[float64]
>>> numba.extending.as_numba_type(numba.int32)
int32

as_numba_type 会自动更新以包含任何 @jitclass

>>> @jitclass
... class Counter:
...     x: int
...
...     def __init__(self):
...         self.x = 0
...
...     def inc(self):
...         old_val = self.x
...         self.x += 1
...         return old_val
...
>>> numba.extending.as_numba_type(Counter)
instance.jitclass.Counter#11bad4278<x:int64>

目前 as_numba_type 仅用于推断 @jitclass 的字段。