支持的 NumPy 功能

Numba 的目标之一是与 NumPy 无缝集成。NumPy 数组为同构数据集提供了高效的存储方法。NumPy dtypes 提供了编译时有用的类型信息,并且内存中潜在大量数据的规则、结构化存储为代码生成提供了理想的内存布局。Numba 擅长生成在 NumPy 数组之上执行的代码。

Numba 对 NumPy 的支持体现在多个方面

  • Numba 理解对 NumPy ufuncs 的调用,并能为其中许多生成等效的本地代码。

  • NumPy 数组在 Numba 中直接受支持。对 NumPy 数组的访问非常高效,因为在可能的情况下,索引会被简化为直接内存访问。

  • Numba 能够生成 ufuncsgufuncs。这意味着可以在 Python 中实现 ufuncs 和 gufuncs,获得与使用 NumPy C API 在 C 扩展模块中实现的 ufuncs/gufuncs 相当的速度。

除非另有说明,以下部分将重点介绍 nopython 模式中支持的 NumPy 功能。

标量类型

Numba 支持以下 NumPy 标量类型

  • 整数:所有有符号或无符号整数,宽度最大为 64 位

  • 布尔值

  • 实数: 单精度 (32 位) 和双精度 (64 位) 实数

  • 复数: 单精度 (2x32 位) 和双精度 (2x64 位) 复数

  • 日期时间和时间戳: 任何单位

  • 字符序列 (但对其不提供任何操作)

  • 结构化标量: 由上述任何类型和上述类型数组组成的结构化标量

以下标量类型和功能不受支持

  • 任意 Python 对象

  • 半精度和扩展精度 实数和复数

  • 嵌套结构化标量 结构化标量的字段不能包含其他结构化标量

NumPy 标量上支持的操作与 intfloat 等效内置类型上的操作几乎相同。您可以使用类型的构造函数从不同类型或宽度进行转换。此外,您可以使用 view(np.<dtype>) 方法对所有相同宽度的 intfloat 类型进行位转换。但是,您必须在 JIT 编译函数中使用 NumPy 构造函数定义标量。例如,以下代码将起作用

>>> import numpy as np
>>> from numba import njit
>>> @njit
... def bitcast():
...     i = np.int64(-1)
...     print(i.view(np.uint64))
...
>>> bitcast()
18446744073709551615

而以下代码将不起作用

>>> import numpy as np
>>> from numba import njit
>>> @njit
... def bitcast(i):
...     print(i.view(np.uint64))
...
>>> bitcast(np.int64(-1))
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
    ...
TypingError: Failed in nopython mode pipeline (step: ensure IR is legal prior to lowering)
'view' can only be called on NumPy dtypes, try wrapping the variable with 'np.<dtype>()'

File "<ipython-input-3-fc40aaab84c4>", line 3:
def bitcast(i):
    print(i.view(np.uint64))

结构化标量支持属性的获取和设置,以及使用常量字符串进行成员查找。存储在局部或全局元组中的字符串被视为常量字符串,可用于成员查找。

import numpy as np
from numba import njit

arr = np.array([(1, 2)], dtype=[('a1', 'f8'), ('a2', 'f8')])
fields_gl = ('a1', 'a2')

@njit
def get_field_sum(rec):
    fields_lc = ('a1', 'a2')
    field_name1 = fields_lc[0]
    field_name2 = fields_gl[1]
    return rec[field_name1] + rec[field_name2]

get_field_sum(arr[0])  # returns 3

也可以将局部或全局元组与 literal_unroll 结合使用

import numpy as np
from numba import njit, literal_unroll

arr = np.array([(1, 2)], dtype=[('a1', 'f8'), ('a2', 'f8')])
fields_gl = ('a1', 'a2')

@njit
def get_field_sum(rec):
    out = 0
    for f in literal_unroll(fields_gl):
        out += rec[f]
    return out

get_field_sum(arr[0])   # returns 3

记录子类型

警告

这是一个实验性功能。

Numba 允许结构化标量的 宽度子类型。例如,dtype([('a', 'f8'), ('b', 'i8')]) 将被视为 dtype([('a', 'f8')]) 的子类型,因为后者是前者的严格子集,即字段 a 具有相同的类型并位于两种类型中的相同位置。子类型关系在某些情况下很重要,例如不允许对特定输入进行编译,但该输入是另一个允许类型的子类型。

import numpy as np
from numba import njit, typeof
from numba.core import types
record1 = np.array([1], dtype=[('a', 'f8')])[0]
record2 = np.array([(2,3)], dtype=[('a', 'f8'), ('b', 'f8')])[0]

@njit(types.float64(typeof(record1)))
def foo(rec):
    return rec['a']

foo(record1)
foo(record2)

如果没有子类型化,最后一行将失败。有了子类型化,将不会触发新的编译,但 record1 的编译函数将用于 record2

另请参阅

NumPy 标量 参考。

数组类型

支持上述任何标量类型的 NumPy 数组,无论其形状或布局如何。

注意

NumPy MaskedArrays 不受支持。

数组访问

数组支持常规迭代。支持完整的基本索引和切片,以及传递 None / np.newaxis 作为附加结果维度的索引。还支持高级索引的一个子集:只允许一个高级索引,并且它必须是一维数组(它也可以与任意数量的基本索引结合使用)。

另请参阅

NumPy 索引 参考。

结构化数组访问

Numba 目前支持通过属性以及通过获取和设置来访问结构化数组中单个元素的字段。这略微超出了 NumPy API 的范围,NumPy API 只允许通过获取和设置来访问字段。例如

from numba import njit
import numpy as np

record_type = np.dtype([("ival", np.int32), ("fval", np.float64)], align=True)

def f(rec):
    value = 2.5
    rec[0].ival = int(value)
    rec[0].fval = value
    return rec

arr = np.ones(1, dtype=record_type)

cfunc = njit(f)

# Works
print(cfunc(arr))

# Does not work
print(f(arr))

上述代码会产生以下输出

[(2, 2.5)]
Traceback (most recent call last):
  File "repro.py", line 22, in <module>
    print(f(arr))
  File "repro.py", line 9, in f
    rec[0].ival = int(value)
AttributeError: 'numpy.void' object has no attribute 'ival'

Numba 编译的函数版本可以执行,但纯 Python 版本会因为不支持的属性访问而引发错误。

注意

此行为最终将被弃用和删除。

属性

支持以下 NumPy 数组属性

flags 对象

flags 属性返回的对象支持 contiguousc_contiguousf_contiguous 属性。

flat 对象

flat 属性返回的对象支持迭代和索引,但请注意:在非 C 连续数组上,索引速度非常慢。

realimag 属性

NumPy 支持这些属性,无论 dtype 如何,但 Numba 选择限制其支持以避免潜在的用户错误。对于数字 dtypes,Numba 遵循 NumPy 的行为。real 属性返回复数数组实部的视图,对于其他数字 dtypes,其行为类似于恒等函数。imag 属性返回复数数组虚部的视图,对于其他数字 dtypes,它返回一个具有相同形状和 dtype 的零数组。对于非数字 dtypes,包括所有结构化/记录 dtypes,使用这些属性将导致编译时 (TypingError) 错误。此行为与 NumPy 不同,但选择此行为是为了避免与这些属性重叠的字段名可能造成的混淆。

计算

支持以下 NumPy 数组方法的基本形式(不带任何可选参数)

相应的顶级 NumPy 函数(如 numpy.prod())同样受支持。

其他方法

支持以下 NumPy 数组方法

  • argmax()(支持 axis 关键字参数)。

  • argmin()(支持 axis 关键字参数)。

  • numpy.argpartition()(仅支持前 2 个参数)

  • argsort()kind 关键字参数支持 'quicksort''mergesort' 值)

  • astype()(仅支持 1 个参数的形式)

  • copy()(无参数)

  • dot()(仅支持 1 个参数的形式)

  • flatten()(无 order 参数;仅 ‘C’ order)

  • item()(无参数)

  • itemset()(仅支持 1 个参数的形式)

  • ptp()(无参数)

  • ravel()(无 order 参数;仅 ‘C’ order)

  • repeat()(无 axis 参数)

  • reshape()(仅支持 1 个参数的形式)

  • sort()(无参数)

  • sum()(带或不带 axis 和/或 dtype 参数。)

    • axis 仅支持 integer 值。

    • 如果 axis 参数是编译时常量,则支持所有有效值。超出范围的值将在编译时导致 LoweringError

    • 如果 axis 参数不是编译时常量,则仅支持 0 到 3 的值。超出范围的值将导致运行时异常。

    • 所有数字 dtypes 都支持 dtype 参数。timedelta 数组可以用作输入数组,但 timedelta 不支持作为 dtype 参数。

    • 当给定 dtype 时,它决定了内部累加器的类型。当未给定 dtype 时,将根据输入数组的 dtype 自动选择,大部分遵循与 NumPy 相同的规则。但是,在 64 位 Windows 上,Numba 对整数输入使用 64 位累加器(int32 输入为 int64uint32 输入为 uint64),而 NumPy 在这些情况下将使用 32 位累加器。

  • transpose()

  • view()(仅支持 1 个参数的形式)

  • __contains__()

在适用情况下,相应的顶级 NumPy 函数(如 numpy.argmax())也同样受支持。

警告

排序可能比 NumPy 的实现稍慢。

函数

线性代数

支持对浮点数和复数的 1D 和 2D 连续数组进行基本线性代数运算

注意

这些函数的实现需要安装 SciPy。

归约

支持以下归约函数

其他函数

支持以下顶级函数

以下构造函数受支持,它们都可以接受数字输入(用于构造标量)或序列(用于构造数组)

  • numpy.bool_

  • numpy.complex64

  • numpy.complex128

  • numpy.float32

  • numpy.float64

  • numpy.int8

  • numpy.int16

  • numpy.int32

  • numpy.int64

  • numpy.intc

  • numpy.intp

  • numpy.uint8

  • numpy.uint16

  • numpy.uint32

  • numpy.uint64

  • numpy.uintc

  • numpy.uintp

以下机器参数类受支持,具有所有纯数字属性

字面数组

Python 和 Numba 都没有实际的数组字面量,但您可以通过在嵌套元组上调用 numpy.array() 来构造任意数组

a = numpy.array(((a, b, c), (d, e, f)))

(Numba 尚不支持嵌套列表)

模块

random

生成器对象

Numba 支持 numpy.random.Generator() 对象。从版本 0.56 开始,用户可以将单个 NumPy Generator 对象传递给 Numba 函数,并在函数内部使用其方法。由于使用了与 NumPy 相同的算法进行随机数生成,因此在相同参数下,NumPy 和 Numba 生成的随机数保持一致(NumPy Generator 方法的相同文档说明也适用)。当前的 Numba 对 Generator 的支持不是线程安全的,因此我们不建议在具有并行执行逻辑的方法中使用 Generator 方法。

注意

NumPy 的 Generator 对象依赖于 BitGenerator 来管理状态并生成随机位,然后将其转换为有用分布的随机值。Numba 将 解箱 (unbox) Generator 对象,并将使用 NumPy 的 ctypes 接口绑定来维护对底层 BitGenerator 对象的引用。因此,Generator 对象可以跨越 JIT 边界,其函数可以在 Numba-Jit 代码中使用。请注意,由于仅维护对 BitGenerator 对象的引用,因此对 Numba 代码之外的特定 Generator 对象状态的任何更改都将影响 Numba 代码内部 Generator 的状态。

x = np.random.default_rng(1)
y = np.random.default_rng(1)

size = 10

@numba.njit
def do_stuff(gen):
    return gen.random(size=int(size / 2))

original = x.random(size=size)
# [0.51182162 0.9504637  0.14415961 0.94864945 0.31183145
#  0.42332645 0.82770259 0.40919914 0.54959369 0.02755911]

numba_func_res = do_stuff(y)
# [0.51182162 0.9504637  0.14415961 0.94864945 0.31183145]

after_numba = y.random(size=int(size / 2))
# [0.42332645 0.82770259 0.40919914 0.54959369 0.02755911]

支持以下 Generator 方法

  • numpy.random.Generator().beta()

  • numpy.random.Generator().chisquare()

  • numpy.random.Generator().exponential()

  • numpy.random.Generator().f()

  • numpy.random.Generator().gamma()

  • numpy.random.Generator().geometric()

  • numpy.random.Generator().integers()lowhigh 都是必需参数。目前不支持 low 和 high 的数组值。)

  • numpy.random.Generator().laplace()

  • numpy.random.Generator().logistic()

  • numpy.random.Generator().lognormal()

  • numpy.random.Generator().logseries()(接受浮点值以及转换为浮点数的数据类型。目前不支持 p 的数组值。)

  • numpy.random.Generator().negative_binomial()

  • numpy.random.Generator().noncentral_chisquare()(接受浮点值以及转换为浮点数的数据类型。目前不支持 dfnumnonc 的数组值。)

  • numpy.random.Generator().noncentral_f()(接受浮点值以及转换为浮点数的数据类型。目前不支持 dfnumdfdennonc 的数组值。)

  • numpy.random.Generator().normal()

  • numpy.random.Generator().pareto()

  • numpy.random.Generator().permutation()(仅接受 NumPy ndarray 和整数。)

  • numpy.random.Generator().poisson()

  • numpy.random.Generator().power()

  • numpy.random.Generator().random()

  • numpy.random.Generator().rayleigh()

  • numpy.random.Generator().shuffle()(仅接受 NumPy ndarray。)

  • numpy.random.Generator().standard_cauchy()

  • numpy.random.Generator().standard_exponential()

  • numpy.random.Generator().standard_gamma()

  • numpy.random.Generator().standard_normal()

  • numpy.random.Generator().standard_t()

  • numpy.random.Generator().triangular()

  • numpy.random.Generator().uniform()

  • numpy.random.Generator().wald()

  • numpy.random.Generator().weibull()

  • numpy.random.Generator().zipf()

注意

由于编译器指令选择的差异,与 NumPy 相比,在 32 位架构以及 linux-aarch64 和 linux-ppc64le 平台上,可能存在 1000 个 ULPs 级别的差异。对于 Linux-x86_64、Windows-x86_64 和 macOS,这些差异不太明显(10 个 ULPs 级别),但不保证遵循异常模式,并且在某些情况下可能会增加。

这种差异不太可能影响随机数生成的“质量”,因为它们是由于使用融合乘加而不是先乘后加时发生的舍入变化引起的。

RandomState 和旧式随机数生成

Numba 支持 numpy.random 模块的顶级函数,但不支持创建单个 RandomState 实例。使用与 标准 random 模块 相同的算法(因此适用相同的注意事项),但具有独立的内部状态:从一个生成器播种或抽取数字不会影响另一个生成器。

支持以下函数。

初始化

警告

从解释型代码(包括 对象模式 代码)调用 numpy.random.seed() 将播种 NumPy 随机生成器,而不是 Numba 随机生成器。要播种 Numba 随机生成器,请参见下面的示例。

from numba import njit
import numpy as np

@njit
def seed(a):
    np.random.seed(a)

@njit
def rand():
    return np.random.rand()


# Incorrect seeding
np.random.seed(1234)
print(rand())

np.random.seed(1234)
print(rand())

# Correct seeding
seed(1234)
print(rand())

seed(1234)
print(rand())

排列

stride_tricks

支持 numpy.lib.stride_tricks 模块中的以下函数

  • as_strided()strides 参数是强制性的,不支持 subok 参数)

  • sliding_window_view()(不支持 subok 参数,writeable 参数不支持,返回的视图始终可写)

标准 ufuncs

Numba 的目标之一是让所有 NumPy 中的标准 ufuncs 都能被 Numba 理解。当编译函数时发现受支持的 ufunc 时,Numba 会将该 ufunc 映射到等效的本地代码。这允许在 nopython 模式下编译的 Numba 代码中使用这些 ufuncs。

限制

目前,只有部分标准 ufuncs 在 nopython 模式下工作。以下是 Numba 所知的所有不同标准 ufuncs 的列表,排序方式与 NumPy 文档中相同。

数学运算

UFUNC

模式

名称

对象模式

nopython 模式

add

subtract

multiply

divide

logaddexp

logaddexp2

true_divide

floor_divide

negative

power

float_power

remainder

mod

fmod

divmod (*)

abs

absolute

fabs

rint

sign

conj

exp

exp2

log

log2

log10

expm1

log1p

sqrt

square

cbrt

reciprocal

conjugate

gcd

lcm

(*) 不支持 timedelta 类型

三角函数

UFUNC

模式

名称

对象模式

nopython 模式

sin

cos

tan

arcsin

arccos

arctan

arctan2

hypot

sinh

cosh

tanh

arcsinh

arccosh

arctanh

deg2rad

rad2deg

degrees

radians

位操作函数

UFUNC

模式

名称

对象模式

nopython 模式

bitwise_and

bitwise_or

bitwise_xor

bitwise_not

invert

left_shift

right_shift

比较函数

UFUNC

模式

名称

对象模式

nopython 模式

greater

greater_equal

less

less_equal

not_equal

equal

logical_and

logical_or

logical_xor

logical_not

maximum

minimum

fmax

fmin

浮点函数

UFUNC

模式

名称

对象模式

nopython 模式

isfinite

isinf

isnan

signbit

copysign

nextafter

modf

ldexp

frexp

floor

ceil

trunc

spacing

日期时间函数

UFUNC

模式

名称

对象模式

nopython 模式

isnat