Notes on Inlining

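There are occasions where it is useful to be able to inline a function at its call site, at the Numba IR level of representation. Decorators such as numba.jit(), numba.extending.overload() and register_jitable() support the keyword argument inline to facilitate this behaviour.

When attempting to inline at this level, it is important to understand what purpose this serves and what effect it will have. Unlike the inlining performed by LLVM, which aims to improve performance, the main reason to inline at the Numba IR level is to allow type inference to cross function boundaries.

As an example, consider the following snippet: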

from numba import njit


@njit
def bar(a):
    a.append(10)


@njit
def foo():
    z = []
    bar(z)


foo()

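This will fail to compile and run, because the type of z cannot be inferred: it would only be refined inside bar. If we now add inline='always' to the decorator for bar, the snippet will compile and run. This is because inlining bar brings the call a.append(10) into foo, so z is refined to hold integers and type inference succeeds.

So, in summary, inlining at the Numba IR level is unlikely to have a performance benefit, whereas inlining at the LLVM level stands a better chance.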

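The inline keyword argument can be one of three values:

  • The string 'never', which is the default and results in the function not being inlined under any circumstances.

  • The string 'always', which results in the function being inlined at all call sites.

  • A Python function that takes three arguments. The first argument is always the ir.Expr node of the call requesting the inline; it is present so that the function can make decisions that depend on the calling context. The second and third arguments are:

    • In the case of an untyped inline, i.e. the inlining that occurs when using the numba.jit() family of decorators, both arguments are numba.ir.FunctionIR instances. The second argument corresponds to the IR of the caller, the third argument to the IR of the callee.

    • In the case of a typed inline, i.e. the inlining that occurs when using numba.extending.overload(), both arguments are instances of a namedtuple with the following fields (corresponding to their standard use in the compiler internals):

      • func_ir - the function's Numba IR.

      • typemap - the function's type map.

      • calltypes - the call types of any calls in the function.

      • signature - the function's signature.

      The second argument holds the information about the caller, the third argument the information about the callee.

    In all cases the function should return True to inline and False to not inline. This essentially permits custom inlining rules (a typical use might be a cost model).

  • Recursive functions with inline='always' will result in a non-terminating compilation. If you wish to avoid this, supply a function to limit the recursion depth (see below).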
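Since a cost model is just a callable taking (expr, caller_info, callee_info), it can be produced by a factory. The sketch below is a hypothetical size-based cost model, not part of Numba: make_size_cost_model, its statement-count threshold and the stand-in objects are assumptions for illustration. It relies only on the IR exposing a blocks mapping whose blocks have a body list (the same attributes the sentinel_cost_model example later in this section reads), and it accepts both argument shapes described above: a FunctionIR for untyped inlining and a namedtuple with a func_ir field for typed inlining.

```python
from collections import namedtuple
from types import SimpleNamespace


def _callee_ir(info):
    # Untyped inlining (@jit) passes the IR itself, which has .blocks;
    # typed inlining (@overload) passes a namedtuple holding it in .func_ir.
    return info if hasattr(info, "blocks") else info.func_ir


def make_size_cost_model(max_statements):
    """Hypothetical helper: build an inline= callable that only inlines
    callees whose IR contains at most max_statements statements."""
    def cost_model(expr, caller_info, callee_info):
        func_ir = _callee_ir(callee_info)
        n_stmts = sum(len(blk.body) for blk in func_ir.blocks.values())
        return n_stmts <= max_statements
    return cost_model


# Plain-Python stand-ins for the two shapes of callee_info (no Numba needed).
TypedInfo = namedtuple("TypedInfo", ["func_ir", "typemap", "calltypes", "signature"])
small_ir = SimpleNamespace(blocks={0: SimpleNamespace(body=["s1", "s2", "s3"])})
large_ir = SimpleNamespace(blocks={0: SimpleNamespace(body=["s"] * 8),
                                   1: SimpleNamespace(body=["s"] * 4)})

cost_model = make_size_cost_model(5)
print(cost_model(None, None, small_ir))                           # True
print(cost_model(None, None, TypedInfo(large_ir, {}, {}, None)))  # False
```

With Numba installed, such a callable would be passed like any other cost model, e.g. @njit(inline=make_size_cost_model(20)), where 20 is an arbitrary example threshold. Counting IR statements is only a crude proxy for cost.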

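Note

No guarantee is made about the order in which functions are assessed for inlining, or about the order in which they are inlined.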

Example using numba.jit()

An example of using all three inline options with the numba.njit() decorator:

from numba import njit
import numba
from numba.core import ir


@njit(inline='never')
def never_inline():
    return 100


@njit(inline='always')
def always_inline():
    return 200


def sentinel_cost_model(expr, caller_info, callee_info):
    # this cost model will return True (i.e. do inlining) if either:
    # a) the callee IR contains an `ir.Const(37)`
    # b) the caller IR contains an `ir.Const(13)` logically prior to the call
    #    site

    # check the callee
    for blk in callee_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 37:
                        return True

    # check the caller
    before_expr = True
    for blk in caller_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Expr):
                    if stmt.value == expr:
                        before_expr = False
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 13:
                        return True & before_expr
    return False


@njit(inline=sentinel_cost_model)
def maybe_inline1():
    # Will not inline based on the callee IR with the declared cost model
    # The following is ir.Const(300).
    return 300


@njit(inline=sentinel_cost_model)
def maybe_inline2():
    # Will inline based on the callee IR with the declared cost model
    # The following is ir.Const(37).
    return 37


@njit
def foo():
    a = never_inline()  # will never inline
    b = always_inline()  # will always inline

    # will not inline as the function does not contain a magic constant known to
    # the cost model, and the IR up to the call site does not contain a magic
    # constant either
    d = maybe_inline1()

    # declare this magic constant to trigger inlining of maybe_inline1 in a
    # subsequent call
    magic_const = 13

    # will inline due to above constant declaration
    e = maybe_inline1()

    # will inline as the maybe_inline2 function contains a magic constant known
    # to the cost model
    c = maybe_inline2()

    return a + b + c + d + e + magic_const


foo()

Which, when executed, produces this (the IR is printed after the legalization pass, enabled via the environment variable NUMBA_DEBUG_PRINT_AFTER="ir_legalization"):

label 0:
    $0.1 = global(never_inline: CPUDispatcher(<function never_inline at 0x7f890ccf9048>)) ['$0.1']
    $0.2 = call $0.1(func=$0.1, args=[], kws=(), vararg=None) ['$0.1', '$0.2']
    del $0.1                                 []
    a = $0.2                                 ['$0.2', 'a']
    del $0.2                                 []
    $0.3 = global(always_inline: CPUDispatcher(<function always_inline at 0x7f890ccf9598>)) ['$0.3']
    del $0.3                                 []
    $const0.1.0 = const(int, 200)            ['$const0.1.0']
    $0.2.1 = $const0.1.0                     ['$0.2.1', '$const0.1.0']
    del $const0.1.0                          []
    $0.4 = $0.2.1                            ['$0.2.1', '$0.4']
    del $0.2.1                               []
    b = $0.4                                 ['$0.4', 'b']
    del $0.4                                 []
    $0.5 = global(maybe_inline1: CPUDispatcher(<function maybe_inline1 at 0x7f890ccf9ae8>)) ['$0.5']
    $0.6 = call $0.5(func=$0.5, args=[], kws=(), vararg=None) ['$0.5', '$0.6']
    del $0.5                                 []
    d = $0.6                                 ['$0.6', 'd']
    del $0.6                                 []
    $const0.7 = const(int, 13)               ['$const0.7']
    magic_const = $const0.7                  ['$const0.7', 'magic_const']
    del $const0.7                            []
    $0.8 = global(maybe_inline1: CPUDispatcher(<function maybe_inline1 at 0x7f890ccf9ae8>)) ['$0.8']
    del $0.8                                 []
    $const0.1.2 = const(int, 300)            ['$const0.1.2']
    $0.2.3 = $const0.1.2                     ['$0.2.3', '$const0.1.2']
    del $const0.1.2                          []
    $0.9 = $0.2.3                            ['$0.2.3', '$0.9']
    del $0.2.3                               []
    e = $0.9                                 ['$0.9', 'e']
    del $0.9                                 []
    $0.10 = global(maybe_inline2: CPUDispatcher(<function maybe_inline2 at 0x7f890ccf9b70>)) ['$0.10']
    del $0.10                                []
    $const0.1.4 = const(int, 37)             ['$const0.1.4']
    $0.2.5 = $const0.1.4                     ['$0.2.5', '$const0.1.4']
    del $const0.1.4                          []
    $0.11 = $0.2.5                           ['$0.11', '$0.2.5']
    del $0.2.5                               []
    c = $0.11                                ['$0.11', 'c']
    del $0.11                                []
    $0.14 = a + b                            ['$0.14', 'a', 'b']
    del b                                    []
    del a                                    []
    $0.16 = $0.14 + c                        ['$0.14', '$0.16', 'c']
    del c                                    []
    del $0.14                                []
    $0.18 = $0.16 + d                        ['$0.16', '$0.18', 'd']
    del d                                    []
    del $0.16                                []
    $0.20 = $0.18 + e                        ['$0.18', '$0.20', 'e']
    del e                                    []
    del $0.18                                []
    $0.22 = $0.20 + magic_const              ['$0.20', '$0.22', 'magic_const']
    del magic_const                          []
    del $0.20                                []
    $0.23 = cast(value=$0.22)                ['$0.22', '$0.23']
    del $0.22                                []
    return $0.23                             ['$0.23']

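Things to note in the above:

  1. The call to the function never_inline remains as a call.

  2. The always_inline function has been inlined; note its const(int, 200) in the caller body.

  3. There is a call to maybe_inline1 before the const(int, 13) declaration; the cost model prevented this from being inlined.

  4. After the const(int, 13), the subsequent call to maybe_inline1 has been inlined, as shown by the const(int, 300) in the caller body.

  5. The function maybe_inline2 has been inlined, as shown by the const(int, 37) in the caller body.

  6. Dead code elimination has not been performed, so superfluous statements are present in the IR (for example, the global(...) loads of inlined functions that are immediately followed by del).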
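The difference between the two maybe_inline1 calls comes entirely from the caller-side half of sentinel_cost_model. The plain-Python sketch below re-runs that check on a tiny stand-in for the caller IR. The Expr, Const and Assign dataclasses are hypothetical stand-ins that mimic only the attributes the cost model reads; they are not the real numba.core.ir classes.

```python
from dataclasses import dataclass
from types import SimpleNamespace


# Hypothetical stand-ins for numba.core.ir.Expr / Const / Assign.
@dataclass(eq=False)  # identity equality: two distinct call sites compare unequal
class Expr:
    op: str


@dataclass
class Const:
    value: object


@dataclass
class Assign:
    value: object


def const_13_precedes(expr, caller_info):
    """Caller-side rule of sentinel_cost_model, applied to the stand-ins."""
    before_expr = True
    for blk in caller_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, Assign):
                if isinstance(stmt.value, Expr) and stmt.value == expr:
                    before_expr = False  # reached the call site itself
                if isinstance(stmt.value, Const) and stmt.value.value == 13:
                    return before_expr   # was the 13 seen before the call?
    return False


# Mirror the relevant statements of foo(): call, magic constant, call.
first_call, second_call = Expr("call"), Expr("call")
caller = SimpleNamespace(blocks={0: SimpleNamespace(body=[
    Assign(first_call), Assign(Const(13)), Assign(second_call)])})

print(const_13_precedes(first_call, caller))   # False: 13 comes after the call
print(const_13_precedes(second_call, caller))  # True: 13 comes before the call
```

Note that the check walks blocks in the iteration order of the blocks mapping, so for a caller with branches "logically prior" is only approximated by that order.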

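Example using numba.extending.overload()

An example of using inlining with the numba.extending.overload() decorator. Most notably, if a function is supplied as the argument to inline, much more information is available through the supplied function arguments for use in decision making. Also, different @overload implementations can have different inlining behaviours, and there are multiple ways to achieve this: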

import numba
from numba.extending import overload
from numba import njit, types


def bar(x):
    """A function stub to overload"""
    pass


@overload(bar, inline='always')
def ol_bar_tuple(x):
    # An overload that will always inline, there is a type guard so that this
    # only applies to UniTuples.
    if isinstance(x, types.UniTuple):
        def impl(x):
            return x[0]
        return impl


def cost_model(expr, caller, callee):
    # Only inline if the type of the argument is an Integer
    return isinstance(caller.typemap[expr.args[0].name], types.Integer)


@overload(bar, inline=cost_model)
def ol_bar_scalar(x):
    # An overload that will inline based on a cost model, it only applies to
    # scalar values in the numerical domain as per the type guard on Number
    if isinstance(x, types.Number):
        def impl(x):
            return x + 1
        return impl


@njit
def foo():

    # This will resolve via `ol_bar_tuple` as the argument is a types.UniTuple
    # instance. It will always be inlined as specified in the decorator for this
    # overload.
    a = bar((1, 2, 3))

    # This will resolve via `ol_bar_scalar` as the argument is a types.Number
    # instance, hence the cost_model will be used to determine whether to
    # inline.
    # The function will be inlined as the value 100 is an IntegerLiteral which
    # is an instance of a types.Integer as required by the cost_model function.
    b = bar(100)

    # This will also resolve via `ol_bar_scalar` as the argument is a
    # types.Number instance, again the cost_model will be used to determine
    # whether to inline.
    # The function will not be inlined as the complex value is not an instance
    # of a types.Integer as required by the cost_model function.
    c = bar(300j)

    return a + b + c


foo()

Which, when executed, produces this (the IR is printed after the legalization pass, enabled via the environment variable NUMBA_DEBUG_PRINT_AFTER="ir_legalization"):

label 0:
    $const0.2 = const(tuple, (1, 2, 3))      ['$const0.2']
    x.0 = $const0.2                          ['$const0.2', 'x.0']
    del $const0.2                            []
    $const0.2.2 = const(int, 0)              ['$const0.2.2']
    $0.3.3 = getitem(value=x.0, index=$const0.2.2) ['$0.3.3', '$const0.2.2', 'x.0']
    del x.0                                  []
    del $const0.2.2                          []
    $0.4.4 = $0.3.3                          ['$0.3.3', '$0.4.4']
    del $0.3.3                               []
    $0.3 = $0.4.4                            ['$0.3', '$0.4.4']
    del $0.4.4                               []
    a = $0.3                                 ['$0.3', 'a']
    del $0.3                                 []
    $const0.5 = const(int, 100)              ['$const0.5']
    x.5 = $const0.5                          ['$const0.5', 'x.5']
    del $const0.5                            []
    $const0.2.7 = const(int, 1)              ['$const0.2.7']
    $0.3.8 = x.5 + $const0.2.7               ['$0.3.8', '$const0.2.7', 'x.5']
    del x.5                                  []
    del $const0.2.7                          []
    $0.4.9 = $0.3.8                          ['$0.3.8', '$0.4.9']
    del $0.3.8                               []
    $0.6 = $0.4.9                            ['$0.4.9', '$0.6']
    del $0.4.9                               []
    b = $0.6                                 ['$0.6', 'b']
    del $0.6                                 []
    $0.7 = global(bar: <function bar at 0x7f6c3710d268>) ['$0.7']
    $const0.8 = const(complex, 300j)         ['$const0.8']
    $0.9 = call $0.7($const0.8, func=$0.7, args=[Var($const0.8, inline_overload_example.py (56))], kws=(), vararg=None) ['$0.7', '$0.9', '$const0.8']
    del $const0.8                            []
    del $0.7                                 []
    c = $0.9                                 ['$0.9', 'c']
    del $0.9                                 []
    $0.12 = a + b                            ['$0.12', 'a', 'b']
    del b                                    []
    del a                                    []
    $0.14 = $0.12 + c                        ['$0.12', '$0.14', 'c']
    del c                                    []
    del $0.12                                []
    $0.15 = cast(value=$0.14)                ['$0.14', '$0.15']
    del $0.14                                []
    return $0.15                             ['$0.15']

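Things to note in the above:

  1. The first group of statements, from const(tuple, (1, 2, 3)) through a = $0.3, is the always-inline overload for the UniTuple argument type.

  2. The second group, from const(int, 100) through b = $0.6, is the overload for the Number argument type. It has been inlined because the cost model function decided to do so, the argument being an Integer type instance.

  3. The third group, the call to bar with const(complex, 300j), is also the overload for the Number argument type. It has not been inlined because the cost model function rejected it, the argument being a Complex type instance.

  4. Dead code elimination has not been performed, so superfluous statements are present in the IR.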

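Using a function to limit the inlining depth of a recursive function

When using recursive inlines, you can make compilation terminate by using a cost model.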

from numba import njit


class CostModel(object):
    # Accept at most max_inlines inlining requests, then refuse; this bounds
    # how deeply a recursive function can be inlined into itself.
    def __init__(self, max_inlines):
        self._count = 0
        self._max_inlines = max_inlines

    def __call__(self, expr, caller, callee):
        ret = self._count < self._max_inlines
        self._count += 1
        return ret

@njit(inline=CostModel(3))
def factorial(n):
    if n <= 0:
        return 1
    return n * factorial(n - 1)

factorial(5)
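Because the counter lives on the CostModel instance, every inlining request for factorial draws on the same budget: the first three requests are accepted and later ones are refused, so the remaining recursive call is left as an ordinary call and compilation can terminate. The sketch below exercises a copy of the class in plain Python, passing None for the expr, caller and callee arguments, which this cost model ignores.

```python
class CostModel(object):
    """Copy of the depth-limiting cost model above (no Numba needed)."""

    def __init__(self, max_inlines):
        self._count = 0
        self._max_inlines = max_inlines

    def __call__(self, expr, caller, callee):
        ret = self._count < self._max_inlines
        self._count += 1
        return ret


# Each request consumes one unit of the budget, whichever call site asks.
model = CostModel(3)
decisions = [model(None, None, None) for _ in range(5)]
print(decisions)  # [True, True, True, False, False]
```

A consequence of this design is that passing one CostModel instance to two decorators would make them share a single budget; creating a fresh instance per decorated function, as in the example above, avoids that.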