为与其他语言一起使用而编译Python函数
CUDA 内置目标弃用通知
Numba 内置的 CUDA 目标已被弃用,进一步的开发已转移到 NVIDIA numba-cuda 包。请参阅 内置 CUDA 目标的弃用和维护状态。
Numba 可以将 Python 代码编译为 PTX 或 LTO-IR,以便可以将 Python 函数集成到用其他语言(例如 C/C++)编写的 CUDA 代码中。它通常用于在库或应用程序的上下文中支持用 Python 编写的用户定义函数。
编译 API 可以在没有 GPU 的情况下使用,因为它不使用任何驱动程序函数,并在此过程中避免初始化 CUDA。它通过以下函数调用:
- numba.cuda.compile(pyfunc, sig, debug=False, lineinfo=False, device=True, fastmath=False, cc=None, opt=True, abi='c', abi_info=None, output='ptx')
将 Python 函数编译为给定参数类型的 PTX 或 LTO-IR。
- 参数
pyfunc – 要编译的 Python 函数。
sig – 表示函数输入和输出类型的签名。如果这是一个不带返回类型的参数类型元组,则此函数返回推断的返回类型。如果传递了包含返回类型的签名,则编译后的代码将包含从推断返回类型到指定返回类型的转换,并且此函数将返回指定的返回类型。
debug (bool) – 是否在编译后的代码中包含调试信息。
lineinfo (bool) – 是否包含从编译代码到源代码的行映射。通常这与优化代码一起使用(因为调试模式会自动包含此信息),所以我们希望 LLVM IR 中有调试信息,但在最终输出中只有行映射。
device (bool) – 是否编译设备函数。
fastmath (bool) – 是否启用快速数学标志(ftz=1, prec_sqrt=0, prec_div=, fma=1)。
cc (tuple) – 要编译的计算能力,格式为元组
(MAJOR, MINOR)
。默认为(5, 0)
。opt (bool) – 启用优化。默认为
True
。abi (str) – 编译函数的 ABI —
"numba"
或"c"
。请注意,Numba ABI 不被视为稳定。目前 C ABI 仅支持设备函数。abi_info (dict) – 包含 ABI 特定选项的字典。
"c"
ABI 支持一个选项"abi_name"
,用于提供包装函数的名称。"numba"
ABI 没有选项。output (str) – 要生成的输出类型,
"ptx"
或"ltoir"
。
- 返回
(code, resty):编译后的代码和推断的返回类型。
- 返回类型
如果设备可用且需要当前设备计算能力的编译代码(例如在使用 Numba 构建 JIT 编译工作流时),可以使用 compile_for_current_device
函数
- numba.cuda.compile_for_current_device(pyfunc, sig, debug=False, lineinfo=False, device=True, fastmath=False, opt=True, abi='c', abi_info=None, output='ptx')
将 Python 函数编译为给定签名的 PTX 或 LTO-IR,用于当前设备的计算能力。此函数调用
compile()
并为当前设备设置适当的cc
值。
大多数用户应该使用上面描述的两个函数;为了向后兼容现有用例,还提供了以下函数:
- numba.cuda.compile_ptx(pyfunc, sig, debug=False, lineinfo=False, device=False, fastmath=False, cc=None, opt=True, abi='numba', abi_info=None)
将 Python 函数编译为给定签名的 PTX。请参阅
compile()
。此函数的默认设置是使用 Numba ABI 编译内核,而不是compile()
默认使用 C ABI 编译设备函数。
- numba.cuda.compile_ptx_for_current_device(pyfunc, sig, debug=False, lineinfo=False, device=False, fastmath=False, opt=True, abi='numba', abi_info=None)
为当前设备的计算能力将 Python 函数编译为给定签名的 PTX。请参阅
compile_ptx()
。
使用 C ABI
Numba 内部使用自己的 ABI — 如 设备函数 ABI 中所述,不带 extern "C"
修饰符。调用 Numba ABI 设备函数需要解决三个问题:
函数的名称将根据 Numba 的 ABI 规则进行修饰——这些规则基于 Itanium C++ ABI 规则,但超出了其规范。
Python 返回值应存储在作为第一个参数传递的指针值中。
编译函数的返回值将包含一个状态码,而不是函数的返回值。对于在 Numba 外部使用 Numba 编译的函数,这通常可以忽略。
解决所有这些问题的简单方法是改为使用 C ABI 编译设备函数。这会产生以下结果:
编译代码中设备函数的名称可以控制。默认情况下,它将与 Python 中的函数名称匹配,因此很容易确定。这是函数的
__name__
,而不是__qualname__
,因为__qualname__
编码了额外的作用域信息,这会使函数名称难以预测,并且在很多情况下,在 C 中是一个非法标识符。Python 代码的返回值直接放置在编译函数的返回值中。
状态码被忽略/不报告,因此不需要处理。
如果需要指定编译函数的名称,可以通过在 abi_info
字典中传递名称,键为 'abi_name'
来控制。
使用 compile()
和 compile_for_current_device()
函数时,默认使用 C ABI 进行编译。compile_ptx()
和 compile_ptx_for_current_device()
函数默认为 Numba ABI,以保持与现有用例的兼容性。
C 和 Numba ABI 示例
以下函数
def add(x, y):
return x + y
例如,使用以下方式编译为 Numba ABI:
ptx, resty = cuda.compile_ptx(add, int32(int32, int32), device=True)
结果是函数原型为:
.visible .func (.param .b32 func_retval0) _ZN8__main__3addB2v1B94cw51cXTLSUwv1sCUt9Uw1VEw0NRRQPKzLTg4gaGKFsG2oMQGEYakJSQB1PQBk0Bynm21OiwU1a0UoLGhDpQE8oxrNQE_3dEii(
.param .b64 _ZN8__main__3addB2v1B94cw51cXTLSUwv1sCUt9Uw1VEw0NRRQPKzLTg4gaGKFsG2oMQGEYakJSQB1PQBk0Bynm21OiwU1a0UoLGhDpQE8oxrNQE_3dEii_param_0,
.param .b32 _ZN8__main__3addB2v1B94cw51cXTLSUwv1sCUt9Uw1VEw0NRRQPKzLTg4gaGKFsG2oMQGEYakJSQB1PQBk0Bynm21OiwU1a0UoLGhDpQE8oxrNQE_3dEii_param_1,
.param .b32 _ZN8__main__3addB2v1B94cw51cXTLSUwv1sCUt9Uw1VEw0NRRQPKzLTg4gaGKFsG2oMQGEYakJSQB1PQBk0Bynm21OiwU1a0UoLGhDpQE8oxrNQE_3dEii_param_2
)
请注意,有三个参数,分别用于返回值的指针、x
和 y
。名称以一种在 Numba 内部之外难以预测的方式被修饰。
使用以下方式编译为 C ABI:
ptx, resty = cuda.compile_ptx(add, int32(int32, int32), device=True, abi="c")
结果是以下 PTX 原型:
.visible .func (.param .b32 func_retval0) add(
.param .b32 add_param_0,
.param .b32 add_param_1
)
函数名称与 Python 源代码函数名称匹配,并且只有两个参数,分别用于 x
和 y
。函数的结果直接放置在返回值中。
add.s32 %r3, %r2, %r1;
st.param.b32 [func_retval0+0], %r3;
为了区分编译后的 add()
函数的不同变体,以下示例在 abi_info
字典中指定了其 ABI 名称:
ptx, resty = cuda.compile_ptx(add, float32(float32, float32), device=True,
abi="c", abi_info={"abi_name": "add_f32"})
结果是 PTX 原型:
.visible .func (.param .b32 func_retval0) add_f32(
.param .b32 add_f32_param_0,
.param .b32 add_f32_param_1
)
这将不会与其他名称的定义(例如,上面 int32
的变体)冲突。