字面量类型说明
注意
本文档描述了一项高级功能,旨在克服与类型相关的编译机制的一些限制。
某些功能需要在编译期间根据字面量值进行专门化,以生成 Numba 成功编译所需的类型稳定代码。这可以通过在类型系统中传播字面量值来实现。Numba 将内联字面量值识别为 numba.types.Literal
。例如
def foo(x):
a = 123
return bar(x, a)
Numba 会将 a
的类型推断为 Literal[int](123)
。bar()
的定义随后可以专门化其实现,因为它知道第二个参数是一个值为 123
的 int
类型。
Literal
类型
与 Literal
类型相关的类和方法。
- class numba.types.Literal(*args, **kwargs)
字面量类型的基类。字面量类型在其类型中包含原始 Python 值。
字面量类型应始终通过 literal(val) 函数构建。
- numba.types.literal(value)
返回一个 Literal 实例或抛出 LiteralTypingError 异常
- numba.types.unliteral(lit_type)
从 Literal 类型获取基类型。
- numba.types.maybe_literal(value)
获取值的 Literal 类型或 None。
为字面量类型指定
要在计划进行 JIT 编译的代码中将值指定为 Literal
类型,请使用以下函数
- numba.literally(obj)
强制 Numba 将 obj 解释为字面量值。
obj 必须是字面量或调用函数的参数,其中该参数必须绑定到一个字面量。字面量要求会沿着调用栈向上S传播。
此函数被编译器拦截,以改变编译行为,将相应的函数参数封装为
Literal
类型。在 nopython 模式(解释器模式和对象模式)之外,它不起作用。当前实现以两种方式检测字面量参数
通过一次编译器传递扫描
literally
的使用。literally
被重载以抛出numba.errors.ForceLiteralArg
异常,以信号调度器(dispatcher)以不同的方式处理相应的参数。此模式旨在支持间接使用(通过函数调用)。
此函数的执行语义等同于一个恒等函数。
请参阅 numba/tests/test_literal_dispatch.py 查看示例。
代码示例
numba/tests/doc_examples/test_literally_usage.py
中的 test_literally_usage
1 import numba
2
3 def power(x, n):
4 raise NotImplementedError
5
6 @numba.extending.overload(power)
7 def ov_power(x, n):
8 if isinstance(n, numba.types.Literal):
9 # only if `n` is a literal
10 if n.literal_value == 2:
11 # special case: square
12 print("square")
13 return lambda x, n: x * x
14 elif n.literal_value == 3:
15 # special case: cubic
16 print("cubic")
17 return lambda x, n: x * x * x
18 else:
19 # If `n` is not literal, request literal dispatch
20 return lambda x, n: numba.literally(n)
21
22 print("generic")
23 return lambda x, n: x ** n
24
25 @numba.njit
26 def test_power(x, n):
27 return power(x, n)
28
29 # should print "square" and "9"
30 print(test_power(3, 2))
31
32 # should print "cubic" and "27"
33 print(test_power(3, 3))
34
35 # should print "generic" and "81"
36 print(test_power(3, 4))
37
内部细节
在内部,编译器会抛出 ForceLiteralArgs
异常,以信号调度器(dispatcher)使用 Literal
类型封装指定的参数。
- class numba.errors.ForceLiteralArg(arg_indices, fold_arguments=None, loc=None)
一个伪异常,用于信号调度器(dispatcher)将参数按字面量类型化
- 属性
- 请求参数frozenset[int]
参数的请求位置。
- __init__(arg_indices, fold_arguments=None, loc=None)
- 参数
- 参数索引Sequence[int]
参数的请求位置。
- fold_arguments: 可调用对象
一个函数
(tuple, dict) -> tuple
,它绑定并展平args
和kwargs
。- 位置numba.ir.Loc or None
- __or__(other)
与 self.combine(other) 相同
- combine(other)
通过对 requested_args 执行或操作返回一个新实例。
在扩展中
@overload
扩展可以在实现体内部使用 literally
,就像在普通的 jit-代码中一样。
字面量要求的显式处理可以通过使用以下内容实现
- class numba.extending.SentryLiteralArgs(literal_args)
- 参数
- 字面量参数Sequence[str]
字面量参数的名称序列
示例
以下行
>>> SentryLiteralArgs(literal_args).for_pysig(pysig).bind(*args, **kwargs)
等同于
>>> sentry_literal_args(pysig, literal_args, args, kwargs)
- for_function(func)
将哨兵(sentry)绑定到 func 的签名。
- 参数
- 函数Function
一个 Python 函数。
- 返回
- 对象BoundLiteralArgs
- for_pysig(pysig)
将哨兵(sentry)绑定到给定的签名 pysig。
- 参数
- pysiginspect.Signature
- 返回
- 对象BoundLiteralArgs
- class numba.extending.BoundLiteralArgs(pysig, literal_args)
此类通常由 SentryLiteralArgs 创建。
- bind(*args, **kwargs)
绑定到参数类型。
- numba.extending.sentry_literal_args(pysig, literal_args, args, kwargs)
确保给定参数类型(在 args 和 kwargs 中)被字面量类型化,适用于具有 Python 签名 pysig 和 literal_args 中字面量参数名称列表的函数。
此外,这与以下内容相同
SentryLiteralArgs(literal_args).for_pysig(pysig).bind(*args, **kwargs)