自定义编译器

警告

自定义管道功能仅供专家使用。修改编译器行为可能会使 Numba 源代码中的内部假设失效。

对于希望扩展或修改编译器行为的库开发者,可以通过继承 numba.compiler.CompilerBase 来定义一个自定义编译器。默认的 Numba 编译器定义为 numba.compiler.Compiler,它实现了 .define_pipelines() 方法,该方法添加了 nopython 模式object 模式解释模式 管道。为方便起见,这三个管道在 numba.compiler.DefaultPassBuilder 中通过以下方法定义:

  • .define_nopython_pipeline()

  • .define_objectmode_pipeline()

  • .define_interpreted_pipeline()

分别。

要使用 CompilerBase 的自定义子类,请将其作为 pipeline_class 关键字参数提供给 @jit 装饰器。通过这样做,自定义管道的效果仅限于被装饰的函数。

实现编译器通道

Numba 使得实现新的编译器通道成为可能,其通过使用类似于 LLVM 的 API 来实现。下面演示了所涉及的基本过程。

编译器通道类

所有通道都必须继承自 numba.compiler_machinery.CompilerPass,常用的子类有:

  • numba.compiler_machinery.FunctionPass 用于描述在函数级别操作并可能修改 IR 状态的通道。

  • numba.compiler_machinery.AnalysisPass 用于描述仅执行分析的通道。

  • numba.compiler_machinery.LoweringPass 用于描述仅执行降低操作的通道。

在此示例中,将实现一个新的编译器通道,它将重写所有 ir.Const(x) 节点,其中 xnumbers.Number 的子类,使 x 的值增加一。此通道除作为教学工具外,没有其他用途!

numba.compiler_machinery.FunctionPass 适用于建议的通道行为,因此它是新通道的基类。此外,定义了一个 run_pass 方法来执行工作(此方法是抽象的,所有编译器通道都必须实现它)。

首先是新类

from numba import njit
from numba.core import ir
from numba.core.compiler import CompilerBase, DefaultPassBuilder
from numba.core.compiler_machinery import FunctionPass, register_pass
from numba.core.untyped_passes import IRProcessing
from numbers import Number

# Register this pass with the compiler framework, declare that it will not
# mutate the control flow graph and that it is not an analysis_only pass (it
# potentially mutates the IR).
@register_pass(mutates_CFG=False, analysis_only=False)
class ConstsAddOne(FunctionPass):
    _name = "consts_add_one" # the common name for the pass

    def __init__(self):
        FunctionPass.__init__(self)

    # implement method to do the work, "state" is the internal compiler
    # state from the CompilerBase instance.
    def run_pass(self, state):
        func_ir = state.func_ir # get the FunctionIR object
        mutated = False # used to record whether this pass mutates the IR
        # walk the blocks
        for blk in func_ir.blocks.values():
            # find the assignment nodes in the block and walk them
            for assgn in blk.find_insts(ir.Assign):
                # if an assignment value is a ir.Consts
                if isinstance(assgn.value, ir.Const):
                    const_val = assgn.value
                    # if the value of the ir.Const is a Number
                    if isinstance(const_val.value, Number):
                        # then add one!
                        const_val.value += 1
                        mutated |= True
        return mutated # return True if the IR was mutated, False if not.

另请注意,该类必须使用 @register_pass 在 Numba 的编译器机制中注册。这部分是为了允许声明通道是否修改控制流图以及它是否仅是分析通道。

接下来,基于现有的 numba.compiler.CompilerBase 定义一个新的编译器。编译器管道通过使用现有管道定义,并且将上面声明的新通道添加到在 IRProcessing 通道之后运行。

class MyCompiler(CompilerBase): # custom compiler extends from CompilerBase

    def define_pipelines(self):
        # define a new set of pipelines (just one in this case) and for ease
        # base it on an existing pipeline from the DefaultPassBuilder,
        # namely the "nopython" pipeline
        pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
        # Add the new pass to run after IRProcessing
        pm.add_pass_after(ConstsAddOne, IRProcessing)
        # finalize
        pm.finalize()
        # return as an iterable, any number of pipelines may be defined!
        return [pm]

最后,更新调用点的 @njit 装饰器,以使用新定义的编译管道。

@njit(pipeline_class=MyCompiler) # JIT compile using the custom compiler
def foo(x):
    a = 10
    b = 20.2
    c = x + a + b
    return c

print(foo(100)) # 100 + 10 + 20.2 (+ 1 + 1), extra + 1 + 1 from the rewrite!

调试编译器通道

观察 IR 变化

通常,能够看到一个通道对 IR 所做的更改会很有用。Numba 通过使用环境变量 NUMBA_DEBUG_PRINT_AFTER 方便地允许这样做。在上述通道的情况下,使用 NUMBA_DEBUG_PRINT_AFTER="ir_processing,consts_add_one" 运行示例代码会得到:

----------------------------nopython: ir_processing-----------------------------
label 0:
    x = arg(0, name=x)                       ['x']
    $const0.1 = const(int, 10)               ['$const0.1']
    a = $const0.1                            ['$const0.1', 'a']
    del $const0.1                            []
    $const0.2 = const(float, 20.2)           ['$const0.2']
    b = $const0.2                            ['$const0.2', 'b']
    del $const0.2                            []
    $0.5 = x + a                             ['$0.5', 'a', 'x']
    del x                                    []
    del a                                    []
    $0.7 = $0.5 + b                          ['$0.5', '$0.7', 'b']
    del b                                    []
    del $0.5                                 []
    c = $0.7                                 ['$0.7', 'c']
    del $0.7                                 []
    $0.9 = cast(value=c)                     ['$0.9', 'c']
    del c                                    []
    return $0.9                              ['$0.9']
----------------------------nopython: consts_add_one----------------------------
label 0:
    x = arg(0, name=x)                       ['x']
    $const0.1 = const(int, 11)               ['$const0.1']
    a = $const0.1                            ['$const0.1', 'a']
    del $const0.1                            []
    $const0.2 = const(float, 21.2)           ['$const0.2']
    b = $const0.2                            ['$const0.2', 'b']
    del $const0.2                            []
    $0.5 = x + a                             ['$0.5', 'a', 'x']
    del x                                    []
    del a                                    []
    $0.7 = $0.5 + b                          ['$0.5', '$0.7', 'b']
    del b                                    []
    del $0.5                                 []
    c = $0.7                                 ['$0.7', 'c']
    del $0.7                                 []
    $0.9 = cast(value=c)                     ['$0.9', 'c']
    del c                                    []
    return $0.9                              ['$0.9']

请注意 const 节点中值的变化。

通道执行时间

Numba 内置支持对所有编译器通道进行计时,执行时间存储在与编译结果关联的元数据中。这演示了基于先前定义的函数 foo 访问此信息的一种方法:

compile_result = foo.overloads[foo.signatures[0]]
nopython_times = compile_result.metadata['pipeline_times']['nopython']
for k in nopython_times.keys():
    if ConstsAddOne._name in k:
        print(nopython_times[k])

其输出例如为:

pass_timings(init=1.914000677061267e-06, run=4.308700044930447e-05, finalize=1.7400006981915794e-06)

这显示了通道的初始化、运行和最终化时间(以秒为单位)。