使用 Numba 重写通道实现乐趣与优化

概述

本节介绍中间表示 (IR) 重写,以及如何利用它们来实现优化。

如前所述,在“阶段 5a: 重写类型化 IR”中,重写 Numba IR 使我们能够执行在较低的 LLVM 级别上难以执行的优化。与 Numba 类型和降级子系统类似,重写子系统是用户可扩展的。这种可扩展性使 Numba 能够支持各种领域特定优化 (DSO)。

其余小节详细介绍了实现重写、向重写注册表注册重写的机制,并提供了添加新重写的示例,以及数组表达式优化通道的内部原理。最后,我们将回顾示例中展示的一些用例,并审视开发人员应注意的任何要点。

重写通道

重写通道具有简单的 match()apply() 接口。匹配和重写之间的划分遵循了在声明式领域特定语言 (DSL) 中定义术语重写的方式。在此类 DSL 中,可以按如下方式编写重写

<match> => <replacement>

<match><replacement> 符号表示 IR 术语表达式,其中左侧呈现要匹配的模式,右侧呈现一个 IR 术语构造函数,用于在匹配时构建。每当重写匹配到 IR 模式时,左侧的任何自由变量都会在自定义环境中绑定。应用时,重写使用模式匹配环境来绑定右侧的任何自由变量。

由于 Python 通常不以声明性方式使用,Numba 使用对象状态来处理匹配和应用步骤之间的信息传输。

Rewrite” 基类

Rewrite

Rewrite 类仅仅定义了一个用于 Numba 重写的抽象基类。开发人员应将重写定义为该基类的子类,并重载 match()apply() 方法。

pipeline

pipeline 属性包含正在编译要重写的函数的 numba.compiler.Pipeline 实例。

__init__(self, pipeline, *args, **kws)

重写的基构造函数只是将其参数存储到同名属性中。除非用于调试或测试,否则重写应仅由 RewriteRegistryRewriteRegistry.apply() 方法中构造,并且构造接口应保持稳定(尽管 pipeline 通常会包含所有已知信息)。

match(self, block, typemap, callmap)

match() 方法除了 self 外,还接受四个参数

  • func_ir: 这是 numba.ir.FunctionIR 的实例,表示正在被重写的函数。

  • block: 这是 numba.ir.Block 的实例。匹配方法应该遍历包含在 numba.ir.Block.body 成员中的指令。

  • typemap: 这是一个 Python dict 实例,将 IR 中的符号名称(表示为字符串)映射到 Numba 类型。

  • callmap: 这是另一个 dict 实例,将调用(表示为 numba.ir.Expr 实例)映射到其对应的调用点类型签名(表示为 numba.typing.templates.Signature 实例)。

match() 方法应返回一个 bool 结果。返回 True 表示找到一个或多个匹配项,并且 apply() 方法将返回一个新的替换 numba.ir.Block 实例。返回 False 表示未找到任何匹配项,并且随后对 apply() 的调用将返回未定义或无效的结果。

apply(self)

apply() 方法应仅在成功调用 match() 之后调用。此方法除了 self 外不接受其他参数,并且应返回一个替换 numba.ir.Block 实例。

如上所述,调用 apply() 的行为是未定义的,除非 match() 已经被调用并返回了 True

子类化 Rewrite

在深入探讨任何 Rewrite 子类必须拥有的重载方法的期望之前,让我们退一步来回顾一下这里正在发生什么。通过提供可扩展的编译器,Numba 允许用户自定义代码生成器,这些生成器可能不完整,甚至不正确。当代码生成器出现问题时,它可能导致程序行为异常或提前终止。用户定义的重写增加了新的复杂性,因为它们不仅必须生成正确的代码,而且它们生成的代码应确保编译器不会陷入匹配/应用循环。编译器的非终止将直接导致用户函数调用的非终止。

有几种方法可以帮助确保重写终止

  • 类型化: 重写通常应尝试分解复合类型,并避免组合新类型。如果重写匹配特定类型,将表达式类型更改为较低级别的类型将确保在应用重写后它们不再匹配。

  • 特殊指令: 重写可以在目标 IR 中合成自定义运算符或使用特殊函数。这种技术再次生成了不再属于原始匹配范围的代码,因此重写将终止。

在下面的“案例研究: 数组表达式”小节中,我们将看到数组表达式重写器如何使用这两种技术。

重载 Rewrite.match()

每个重写开发人员都应力求使其 match() 实现尽快返回 False 值。Numba 是一个即时编译器,增加编译时间最终会增加用户的运行时间。当重写对给定块返回 False 时,注册表将不再使用该重写处理该块,编译器也更接近于执行降级操作。

这种对及时性的需求必须与收集进行重写匹配所需的信息相平衡。重写开发人员应该乐于向其子类添加动态属性,然后让这些新属性指导替换基本块的构建。

重载 Rewrite.apply()

apply() 方法应返回一个替换 numba.ir.Block 实例,以替换包含重写匹配项的基本块。如上所述,由 apply() 方法构建的 IR 应该保留用户代码的语义,但也应避免为相同的重写或一组重写生成另一个匹配项。

重写注册表

当您想在重写通道中包含一个重写时,应将其注册到重写注册表。numba.rewrites 模块提供了抽象基类和类装饰器,用于接入 Numba 重写子系统。以下是一个新重写的存根定义示例

from numba import rewrites

@rewrites.register_rewrite
class MyRewrite(rewrites.Rewrite):

    def match(self, block, typemap, calltypes):
        raise NotImplementedError("FIXME")

    def apply(self):
        raise NotImplementedError("FIXME")

开发人员应注意,如上所示使用类装饰器将在导入时注册重写。开发人员有责任确保其扩展在编译开始之前加载。

案例研究: 数组表达式

本小节更深入地探讨了数组表达式重写器。数组表达式重写器及其大部分支持功能都在 numba.npyufunc.array_exprs 模块中。重写通道本身在 RewriteArrayExprs 类中实现。除了重写器之外,array_exprs 模块还包含一个用于降低数组表达式的函数,即 _lower_array_expr()。总体优化过程如下

  • RewriteArrayExprs.match(): 重写通道查找构成数组表达式的一个或多个数组操作。

  • RewriteArrayExprs.apply(): 一旦找到数组表达式,重写器就会将单个数组操作替换为一种新的 IR 表达式,即 arrayexpr

  • numba.npyufunc.array_exprs._lower_array_expr(): 在降级过程中,当代码生成器遇到 arrayexpr IR 表达式时,它会调用 _lower_array_expr()

有关优化每个步骤的更多详细信息如下。

RewriteArrayExprs.match() 方法

数组表达式优化通道首先寻找数组操作,包括对受支持的 ufunc 和用户定义的 DUFunc 的调用。Numba IR 遵循静态单赋值 (SSA) 语言的约定,这意味着数组操作符的搜索从寻找赋值指令开始。

当重写通道调用 RewriteArrayExprs.match() 方法时,它首先检查是否可以简单地拒绝该基本块。如果该方法确定该块是匹配的候选块,它会在重写对象中设置以下状态变量

  • crnt_block: 当前正在匹配的基本块。

  • typemap: 正在匹配的函数的 typemap

  • matches: 引用数组表达式的变量名列表。

  • array_assigns: 一个映射,将赋值变量名映射到定义给定变量的实际赋值指令。

  • const_assigns: 一个映射,将赋值变量名映射到定义常量变量的常量值表达式。

此时,匹配方法遍历输入基本块中的赋值指令。对于每个赋值指令,匹配器查找以下两种情况之一

  • 数组操作: 如果赋值指令的右侧是一个表达式,并且该表达式的结果是数组类型,则匹配器会检查该表达式是已知的数组操作,还是对通用函数的调用。如果找到数组操作符,匹配器会将左侧变量名和整个指令存储在 array_assigns 成员中。最后,匹配器检查数组操作的任何操作数是否也被标识为其他数组操作的目标。如果一个或多个操作数也是数组操作的目标,则匹配器还会将左侧变量名附加到 matches 成员中。

  • 常量: 常量(甚至是标量)可以是数组操作的操作数。无需担心常量是否是数组表达式的一部分,匹配器会将常量名称和值存储在 const_assigns 成员中。

匹配方法的最后简单地检查 matches 列表是否非空,如果存在一个或多个匹配项,则返回 True;如果 matches 为空,则返回 False

RewriteArrayExprs.apply() 方法

RewriteArrayExprs.match() 找到一个或多个匹配的数组表达式时,重写通道将调用 RewriteArrayExprs.apply()。apply 方法分两个通道工作。第一个通道遍历找到的匹配项,并构建一个从旧基本块中的指令到新基本块中的新指令的映射。第二个通道遍历旧基本块中的指令,复制未因重写而更改的指令,并替换或删除第一个通道识别出的指令。

RewriteArrayExprs._handle_matches() 实现了重写代码生成部分的第一遍。对于每个匹配项,此方法都会构建一个特殊的 IR 表达式,其中包含数组表达式的表达式树。为了计算表达式树的叶子节点,_handle_matches() 方法会遍历已识别的根操作的操作数。如果操作数是另一个数组操作,则会将其转换为表达式子树。如果操作数是常量,_handle_matches() 会复制常量值。否则,该操作数被标记为由数组表达式使用。当方法构建数组表达式节点时,它会构建一个从旧指令到新指令的映射 (replace_map),以及可能已移动的变量集 (used_vars) 和应完全删除的变量 (dead_vars) 集。这三个数据结构会返回给调用 RewriteArrayExprs.apply() 方法的函数。

RewriteArrayExprs.apply() 方法的其余部分遍历旧基本块中的指令。对于每条指令,此方法会根据 RewriteArrayExprs._handle_matches() 的结果替换、删除或复制该指令。以下列表描述了优化如何处理单个指令

  • 当指令是赋值指令时,apply() 会检查它是否在替换指令映射中。当在指令映射中找到赋值指令时,apply() 必须接着检查替换指令是否也在替换映射中。优化器会继续此检查,直到遇到 None 值或不在替换映射中的指令。替换为 None 的指令将被删除。具有非 None 替换的指令将被替换。不在替换映射中的赋值指令会原样附加到新的基本块中。

  • 当指令是删除指令时,重写会检查它是否删除了可能仍被后续数组表达式使用的变量,或者它是否删除了一个死变量。用于已使用变量的删除指令会添加到延迟删除指令映射中,apply() 使用此映射将它们移到该变量的任何使用之后。循环会复制非死变量的删除指令,并忽略死变量的删除指令(有效地将它们从基本块中移除)。

  • 所有其他指令都附加到新的基本块中。

最后,apply() 方法返回用于降级的新基本块。

_lower_array_expr() 函数

如果仅仅停留在重写阶段,那么编译器的降级阶段将会失败,抱怨它不知道如何降级 arrayexpr 操作。我们首先通过在编译器实例化 RewriteArrayExprs 类时,将一个降级函数挂接到目标上下文中。这个挂接会使降级通道在遇到 arrayexr 操作符时调用 _lower_array_expr()

此函数包含两个步骤

  • 合成一个实现数组表达式的 Python 函数: 这个新的 Python 函数本质上就像一个 Numpy ufunc,返回广播数组参数中标量值的表达式结果。降级函数通过将数组表达式树转换为 Python AST 来实现这一点。

  • 将合成的 Python 函数编译成内核: 此时,降级函数依赖于现有的用于降级 ufunc 和 DUFunc 内核的代码,在定义如何降级对合成函数的调用后,会调用 numba.targets.numpyimpl.numpy_ufunc_kernel()

最终结果类似于 Numba 对象模式中的循环提升。

结论与注意事项

我们已经了解了如何在 Numba 中实现重写,从接口开始,到实际的优化结束。本节的要点是

  • 在编写一个好的插件时,匹配器应尽量尽快得到一个通过/不通过的结果。

  • 重写应用部分可能计算成本更高,但仍应生成不会导致编译器陷入无限循环的代码。

  • 我们使用对象状态来将匹配结果传递给重写应用通道。