NBEP 2: 扩展点

作者

Antoine Pitrou

日期

2015 年 7 月

状态

草稿

在 Numba 中实现新类型或函数需要连接到编译链中(以及可能在链外)的各种机制。本文档首先旨在检查当前的方法,其次提出使扩展更容易的建议。

如果其中一些提案得以实施,我们应该首先努力在内部使用和实践它们,然后再向公众公开 API。

注意

本文档不涵盖 CUDA 或任何其他非 CPU 后端。

高级 API

目前没有高级 API,使得某些用例比应有的更复杂。

拟议的更改

专用模块

我们提议添加一个 numba.extending 模块,暴露用于扩展 Numba 的主要 API。

实现函数

我们提议添加一个 @overload 装饰器,允许实现在 nopython 模式下使用的给定函数。重载函数具有与实现函数相同的形式签名,并接收实际的参数类型。它应该返回一个实现给定类型重载函数的 Python 函数。

以下示例使用此方法实现了 numpy.where()

import numpy as np

from numba.core import types
from numba.extending import overload

@overload(np.where)
def where(cond, x, y):
    """
    Implement np.where().
    """
    # Choose implementation based on argument types.
    if isinstance(cond, types.Array):
        # Array where() => return an array of the same shape
        if all(ty.layout == 'C' for ty in (cond, x, y)):
            def where_impl(cond, x, y):
                """
                Fast implementation for C-contiguous arrays
                """
                shape = cond.shape
                if x.shape != shape or y.shape != shape:
                    raise ValueError("all inputs should have the same shape")
                res = np.empty_like(x)
                cf = cond.flat
                xf = x.flat
                yf = y.flat
                rf = res.flat
                for i in range(cond.size):
                    rf[i] = xf[i] if cf[i] else yf[i]
                return res
        else:
            def where_impl(cond, x, y):
                """
                Generic implementation for other arrays
                """
                shape = cond.shape
                if x.shape != shape or y.shape != shape:
                    raise ValueError("all inputs should have the same shape")
                res = np.empty_like(x)
                for idx, c in np.ndenumerate(cond):
                    res[idx] = x[idx] if c else y[idx]
                return res

    else:
        def where_impl(cond, x, y):
            """
            Scalar where() => return a 0-dim array
            """
            scal = x if cond else y
            return np.full_like(scal, scal)

    return where_impl

也可以实现 Numba 已知的函数,以支持额外的类型。以下示例使用此方法为元组实现了内置函数 len()

@overload(len)
def tuple_len(x):
   if isinstance(x, types.BaseTuple):
      # The tuple length is known at compile-time, so simply reify it
      # as a constant.
      n = len(x)
      def len_impl(x):
         return n
      return len_impl

实现属性

我们提议添加一个 @overload_attribute 装饰器,允许实现在 nopython 模式下使用的属性获取器。

以下示例实现了 Numpy 数组上的 .nbytes 属性

@overload_attribute(types.Array, 'nbytes')
def array_nbytes(arr):
   def get(arr):
       return arr.size * arr.itemsize
   return get

注意

overload_attribute() 签名允许扩展以定义设置器(setter)和删除器(deleter),通过让被装饰函数返回一个 getter, setter, deleter 元组,而不是单个的 getter

实现方法

我们提议添加一个 @overload_method 装饰器,允许实现在 nopython 模式下使用的实例方法。

以下示例实现了 Numpy 数组上的 .take() 方法

@overload_method(types.Array, 'take')
def array_take(arr, indices):
   if isinstance(indices, types.Array):
       def take_impl(arr, indices):
           n = indices.shape[0]
           res = np.empty(n, arr.dtype)
           for i in range(n):
               res[i] = arr[indices[i]]
           return res
       return take_impl

暴露结构体成员

我们提议添加一个 make_attribute_wrapper() 函数,将内部字段暴露为可见的只读属性,适用于那些由 StructModel 数据模型支持的类型。

例如,假设 PdIndexType 是 pandas 索引的 Numba 类型,以下是如何将底层 Numpy 数组暴露为 ._data 属性

@register_model(PdIndexType)
class PdIndexModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [
            ('values', fe_type.as_array),
            ]
        models.StructModel.__init__(self, dmm, fe_type, members)

make_attribute_wrapper(PdIndexType, 'values', '_data')

类型系统

Numba 类型

Numba 的标准类型在 numba.types 中声明。要声明新类型,可以继承基础 Type 类或其现有抽象子类,并实现所需的功能。

拟议的更改

无需更改。

值的类型推断

新类型的值如果可以作为函数参数或常量出现,则需要进行类型推断。核心机制在 numba.typing.typeof 中。

在某些 Python 类或类独占映射到新类型的常见情况下,可以扩展一个泛型函数以根据这些类进行调度,例如:

from numba.typing.typeof import typeof_impl

@typeof_impl(MyClass)
def _typeof_myclass(val, c):
   if "some condition":
      return MyType(...)

typeof_impl 特化必须返回一个 Numba 类型实例,如果值类型推断失败则返回 None。

(当控制被推断类型的类时,typeof_impl 的一个替代方法是在类上定义一个 _numba_type_ 属性)

在新类型可以表示无法枚举的各种 Python 类的罕见情况下,必须在 typeof_impl 泛型函数的回退实现中插入手动检查。

拟议的更改

允许人们定义一个泛型钩子,而无需对回退实现进行猴子补丁。

函数参数类型推断的快速路径

(可选)可能希望允许新类型参与快速类型解析(用 C 代码编写),以最大限度地减少 JIT 编译函数使用新类型调用时的函数调用开销。此时必须在 _typeof.c 文件中插入必要的检查和实现,可能在 compute_fingerprint() 函数内部。

拟议的更改

无。在 C Python 扩展中嵌入 C 代码的泛型钩子添加过于复杂。

操作的类型推断

各种操作(函数调用、运算符等)产生的值使用一组称为“模板”的辅助工具进行类型推断。可以通过继承现有基类之一来定义新模板,并实现所需的推断机制。模板使用装饰器显式注册到类型推断机制中。

ConcreteTemplate 基类允许将推断定义为给定操作的一组支持签名。以下示例为模运算符进行类型推断

@builtin
class BinOpMod(ConcreteTemplate):
    key = "%"
    cases = [signature(op, op, op)
             for op in sorted(types.signed_domain)]
    cases += [signature(op, op, op)
              for op in sorted(types.unsigned_domain)]
    cases += [signature(op, op, op) for op in sorted(types.real_domain)]

(请注意,签名中使用的是类型实例,严重限制了可表达的泛型性)

AbstractTemplate 基类允许以编程方式定义推断,赋予其完全的灵活性。这是一个关于如何表达元组索引(即 __getitem__ 运算符)的简化示例

@builtin
class GetItemUniTuple(AbstractTemplate):
    key = "getitem"

    def generic(self, args, kws):
        tup, idx = args
        if isinstance(tup, types.UniTuple) and isinstance(idx, types.Integer):
            return signature(tup.dtype, tup, idx)

AttributeTemplate 基类允许为给定类型的属性和方法进行类型推断。这是一个示例,为复数的 .real.imag 属性进行类型推断

@builtin_attr
class ComplexAttribute(AttributeTemplate):
    key = types.Complex

    def resolve_real(self, ty):
        return ty.underlying_float

    def resolve_imag(self, ty):
        return ty.underlying_float

注意

AttributeTemplate 仅适用于获取属性。设置属性值在 numba.typeinfer 中是硬编码的。

CallableTemplate 基类提供了一种更简单的方法来解析灵活的函数签名,通过允许定义一个与被推断类型的函数具有相同定义的 Callable。例如,如果 Numba 支持列表,就可以假定为 Python 的 sorted 函数进行类型推断

@builtin
class Sorted(CallableTemplate):
    key = sorted

    def generic(self):
        def typer(iterable, key=None, reverse=None):
            if reverse is not None and not isinstance(reverse, types.Boolean):
                return
            if key is not None and not isinstance(key, types.Callable):
                return
            if not isinstance(iterable, types.Iterable):
                return
            return types.List(iterable.iterator_type.yield_type)

        return typer

(请注意,您可以只返回函数的返回类型,而不是完整的签名)

拟议的更改

各种装饰器的命名相当模糊和混乱。我们提议将 @builtin 重命名为 @infer,将 @builtin_attr 重命名为 @infer_getattr,并将 builtin_global 重命名为 infer_global

全局值的两步声明有些冗长,我们提议通过允许将 infer_global 用作装饰器来简化它

@infer_global(len)
class Len(AbstractTemplate):
    key = len

    def generic(self, args, kws):
        assert not kws
        (val,) = args
        if isinstance(val, (types.Buffer, types.BaseTuple)):
            return signature(types.intp, val)

基于类的 API 可能感觉笨拙,我们可以为某些模板类型添加一个函数式 API

@type_callable(sorted)
def type_sorted(context):
    def typer(iterable, key=None, reverse=None):
        # [same function as above]

    return typer

代码生成

Numba 类型值的具体表示

任何具体的 Numba 类型都必须能够以 LLVM 形式表示(用于变量存储、参数传递等)。可以通过实现一个数据模型类并使用装饰器注册它来定义该表示。标准类型的数据模型类在 numba.datamodel.models 中定义。

拟议的更改

无需更改。

类型转换

Numba 类型之间的隐式转换当前在 BaseContext.cast() 方法中实现为一系列单一的选择和类型检查。要添加新的隐式转换,可以在该方法中附加一个类型特定的检查。

布尔求值是隐式转换的一种特殊情况(目标类型是 types.Boolean)。

注意

显式转换被视为常规操作,例如构造函数调用。

拟议的更改

添加一个用于隐式转换的泛型函数,基于源类型和目标类型的多重分派。这是一个示例,展示如何编写浮点数到整数的转换

@lower_cast(types.Float, types.Integer)
def float_to_integer(context, builder, fromty, toty, val):
    lty = context.get_value_type(toty)
    if toty.signed:
        return builder.fptosi(val, lty)
    else:
        return builder.fptoui(val, lty)

操作的实现

其他操作通过一组泛型函数和装饰器实现并注册。例如,以下是如何在 Numpy 数组上实现 .ndim 属性查找的方法

@builtin_attr
@impl_attribute(types.Kind(types.Array), "ndim", types.intp)
def array_ndim(context, builder, typ, value):
    return context.get_constant(types.intp, typ.ndim)

以下是如何在元组值上实现 len() 调用

@builtin
@implement(types.len_type, types.Kind(types.BaseTuple))
def tuple_len(context, builder, sig, args):
    tupty, = sig.args
    retty = sig.return_type
    return context.get_constant(retty, len(tupty.types))

拟议的更改

审查并精简 API。取消显式编写 types.Kind(...) 的要求。移除独立的 @implement 装饰器,并将 @builtin 重命名为 @lower_builtin,将 @builtin_attr 重命名为 @lower_getattr 等。

添加装饰器以实现 setattr() 操作,命名为 @lower_setattr@lower_setattr_generic

从/到 Python 对象的转换

某些类型需要从 Python 对象转换或转换为 Python 对象,如果它们可以作为函数参数传递或从函数返回。相应的装箱和拆箱操作使用泛型函数实现。标准 Numba 类型的实现位于 numba.targets.boxing 中。例如,以下是布尔值的装箱实现

@box(types.Boolean)
def box_bool(c, typ, val):
    longval = c.builder.zext(val, c.pyapi.long)
    return c.pyapi.bool_from_long(longval)

拟议的更改

将实现签名从 (c, typ, val) 更改为 (typ, val, c),以匹配 typeof_impl 泛型函数所选择的签名。