字面量类型说明

注意

本文档描述了一项高级功能,旨在克服与类型相关的编译机制的一些限制。

某些功能需要在编译期间根据字面量值进行专门化,以生成 Numba 成功编译所需的类型稳定代码。这可以通过在类型系统中传播字面量值来实现。Numba 将内联字面量值识别为 numba.types.Literal。例如

def foo(x):
    a = 123
    return bar(x, a)

Numba 会将 a 的类型推断为 Literal[int](123)bar() 的定义随后可以专门化其实现,因为它知道第二个参数是一个值为 123int 类型。

Literal 类型

Literal 类型相关的类和方法。

class numba.types.Literal(*args, **kwargs)

字面量类型的基类。字面量类型在其类型中包含原始 Python 值。

字面量类型应始终通过 literal(val) 函数构建。

numba.types.literal(value)

返回一个 Literal 实例或抛出 LiteralTypingError 异常

numba.types.unliteral(lit_type)

从 Literal 类型获取基类型。

numba.types.maybe_literal(value)

获取值的 Literal 类型或 None。

为字面量类型指定

要在计划进行 JIT 编译的代码中将值指定为 Literal 类型,请使用以下函数

numba.literally(obj)

强制 Numba 将 obj 解释为字面量值。

obj 必须是字面量或调用函数的参数,其中该参数必须绑定到一个字面量。字面量要求会沿着调用栈向上S传播。

此函数被编译器拦截,以改变编译行为,将相应的函数参数封装为 Literal 类型。在 nopython 模式(解释器模式和对象模式)之外,它不起作用

当前实现以两种方式检测字面量参数

  1. 通过一次编译器传递扫描 literally 的使用。

  2. literally 被重载以抛出 numba.errors.ForceLiteralArg 异常,以信号调度器(dispatcher)以不同的方式处理相应的参数。此模式旨在支持间接使用(通过函数调用)。

此函数的执行语义等同于一个恒等函数。

请参阅 numba/tests/test_literal_dispatch.py 查看示例。

代码示例

来自 numba/tests/doc_examples/test_literally_usage.py 中的 test_literally_usage
 1        import numba
 2
 3        def power(x, n):
 4            raise NotImplementedError
 5
 6        @numba.extending.overload(power)
 7        def ov_power(x, n):
 8            if isinstance(n, numba.types.Literal):
 9                # only if `n` is a literal
10                if n.literal_value == 2:
11                    # special case: square
12                    print("square")
13                    return lambda x, n: x * x
14                elif n.literal_value == 3:
15                    # special case: cubic
16                    print("cubic")
17                    return lambda x, n: x * x * x
18            else:
19                # If `n` is not literal, request literal dispatch
20                return lambda x, n: numba.literally(n)
21
22            print("generic")
23            return lambda x, n: x ** n
24
25        @numba.njit
26        def test_power(x, n):
27            return power(x, n)
28
29        # should print "square" and "9"
30        print(test_power(3, 2))
31
32        # should print "cubic" and "27"
33        print(test_power(3, 3))
34
35        # should print "generic" and "81"
36        print(test_power(3, 4))
37

内部细节

在内部,编译器会抛出 ForceLiteralArgs 异常,以信号调度器(dispatcher)使用 Literal 类型封装指定的参数。

class numba.errors.ForceLiteralArg(arg_indices, fold_arguments=None, loc=None)

一个伪异常,用于信号调度器(dispatcher)将参数按字面量类型化

属性
请求参数frozenset[int]

参数的请求位置。

__init__(arg_indices, fold_arguments=None, loc=None)
参数
参数索引Sequence[int]

参数的请求位置。

fold_arguments: 可调用对象

一个函数 (tuple, dict) -> tuple,它绑定并展平 argskwargs

位置numba.ir.Loc or None
__or__(other)

与 self.combine(other) 相同

combine(other)

通过对 requested_args 执行或操作返回一个新实例。

在扩展中

@overload 扩展可以在实现体内部使用 literally,就像在普通的 jit-代码中一样。

字面量要求的显式处理可以通过使用以下内容实现

class numba.extending.SentryLiteralArgs(literal_args)
参数
字面量参数Sequence[str]

字面量参数的名称序列

示例

以下行

>>> SentryLiteralArgs(literal_args).for_pysig(pysig).bind(*args, **kwargs)

等同于

>>> sentry_literal_args(pysig, literal_args, args, kwargs)
for_function(func)

将哨兵(sentry)绑定到 func 的签名。

参数
函数Function

一个 Python 函数。

返回
对象BoundLiteralArgs
for_pysig(pysig)

将哨兵(sentry)绑定到给定的签名 pysig

参数
pysiginspect.Signature
返回
对象BoundLiteralArgs
class numba.extending.BoundLiteralArgs(pysig, literal_args)

此类通常由 SentryLiteralArgs 创建。

bind(*args, **kwargs)

绑定到参数类型。

numba.extending.sentry_literal_args(pysig, literal_args, args, kwargs)

确保给定参数类型(在 argskwargs 中)被字面量类型化,适用于具有 Python 签名 pysigliteral_args 中字面量参数名称列表的函数。

此外,这与以下内容相同

SentryLiteralArgs(literal_args).for_pysig(pysig).bind(*args, **kwargs)