示例:区间类型

在此示例中,我们将扩展 Numba 前端,以添加对 Numba 内部不支持的用户定义类的支持。这将允许

  • 将类的实例传递给 Numba 函数

  • 在 Numba 函数中访问类的属性

  • 从 Numba 函数构造并返回该类的新实例

(以上所有操作均在 nopython 模式下)

我们将混合使用 高级扩展 API低级扩展 API,具体取决于给定任务可用的 API。

我们示例的起点是以下纯 Python 类

class Interval(object):
    """
    A half-open interval on the real number line.
    """
    def __init__(self, lo, hi):
        self.lo = lo
        self.hi = hi

    def __repr__(self):
        return 'Interval(%f, %f)' % (self.lo, self.hi)

    @property
    def width(self):
        return self.hi - self.lo

扩展类型层

创建新的 Numba 类型

由于 Numba 不认识 Interval 类,我们必须创建一个新的 Numba 类型来表示其实例。Numba 不直接处理 Python 类型:它有自己的类型系统,允许不同粒度级别以及常规 Python 类型中不可用的各种元信息。

我们首先创建一个类型类 IntervalType,并且由于我们不需要类型是参数化的,因此我们实例化一个单一类型实例 interval_type

from numba import types

class IntervalType(types.Type):
    def __init__(self):
        super(IntervalType, self).__init__(name='Interval')

interval_type = IntervalType()

Python 值的类型推断

Numba 类型的创建本身并不能做任何事情。我们必须教会 Numba 如何将某些 Python 值推断为该类型的实例。在此示例中,这很简单:Interval 类的任何实例都应被视为属于 interval_type 类型

from numba.extending import typeof_impl

@typeof_impl.register(Interval)
def typeof_index(val, c):
    return interval_type

因此,函数参数和全局值只要是 Interval 的实例,就会被识别为属于 interval_type

Python 注解的类型推断

虽然 typeof 用于推断 Python 对象的 Numba 类型,但 as_numba_type 用于推断 Python 类型的 Numba 类型。对于简单的情况,我们可以简单地注册 Python 类型 Interval 对应于 Numba 类型 interval_type

from numba.extending import as_numba_type

as_numba_type.register(Interval, interval_type)

请注意,as_numba_type 仅用于在编译时从类型注解推断类型。上面的 typeof 注册表用于在运行时推断对象的类型。

操作的类型推断

我们希望能够从 Numba 函数构造区间对象,因此我们必须教会 Numba 识别双参数 Interval(lo, hi) 构造函数。参数应为浮点数

from numba.extending import type_callable

@type_callable(Interval)
def type_interval(context):
    def typer(lo, hi):
        if isinstance(lo, types.Float) and isinstance(hi, types.Float):
            return interval_type
    return typer

type_callable() 装饰器指定在对给定可调用对象(此处为 Interval 类本身)运行类型推断时应调用被装饰的函数。被装饰的函数必须简单地返回一个将使用参数类型调用的类型器函数。这种看似复杂的设置的原因是,类型器函数必须与类型化的可调用对象具有完全相同的签名。这允许正确处理关键字参数。

被装饰函数接收的 context 参数在更复杂的情况下很有用,在这些情况下,计算可调用对象的返回类型需要解析其他类型。

扩展降低层

我们已经完成了关于类型推断扩展的 Numba 教学。现在我们必须教会 Numba 如何实际生成新操作的代码和数据。

定义原生区间的概念模型

通常,nopython 模式 不会对 CPython 解释器生成的 Python 对象进行操作。解释器使用的表示对于快速原生代码来说效率太低。因此,nopython 模式 中支持的每种类型都必须定义一个定制的原生表示,也称为数据模型

一种常见的数据模型是不可变的类结构体数据模型,类似于 C 语言中的 struct。我们的区间数据类型恰好属于这一类别,下面是一个可能的数据模型

from numba.extending import models, register_model

@register_model(IntervalType)
class IntervalModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [('lo', types.float64),
                   ('hi', types.float64),]
        models.StructModel.__init__(self, dmm, fe_type, members)

这指示 Numba,类型为 IntervalType 的值(或其任何实例)表示为包含两个字段 lohi 的结构体,每个字段都是一个双精度浮点数(types.float64)。

注意

可变类型需要更复杂的数据模型才能在修改后持久化其值。它们通常不能像不可变类型那样存储在堆栈或寄存器中进行传递。

公开数据模型属性

我们希望数据模型属性 lohi 以相同的名称公开,供 Numba 函数使用。Numba 提供了一个方便的函数来完成此操作

from numba.extending import make_attribute_wrapper

make_attribute_wrapper(IntervalType, 'lo', 'lo')
make_attribute_wrapper(IntervalType, 'hi', 'hi')

这将以只读模式公开属性。如上所述,可写属性不符合此模型。

公开属性

由于 width 属性是计算得出的,而不是存储在结构体中的,我们不能像处理 lohi 那样简单地公开它。我们必须显式地重新实现它

from numba.extending import overload_attribute

@overload_attribute(IntervalType, "width")
def get_width(interval):
    def getter(interval):
        return interval.hi - interval.lo
    return getter

您可能会问,为什么我们不需要为这个属性公开类型推断钩子?答案是 @overload_attribute 是高级 API 的一部分:它在一个 API 中结合了类型推断和代码生成。

实现构造函数

现在我们要实现双参数 Interval 构造函数

from numba.extending import lower_builtin
from numba.core import cgutils

@lower_builtin(Interval, types.Float, types.Float)
def impl_interval(context, builder, sig, args):
    typ = sig.return_type
    lo, hi = args
    interval = cgutils.create_struct_proxy(typ)(context, builder)
    interval.lo = lo
    interval.hi = hi
    return interval._getvalue()

这里还有一些细节。@lower_builtin 装饰器为给定可调用对象或操作(此处为 Interval 构造函数)的实现指定了某些特定的参数类型。这允许为给定操作定义特定于类型的实现,这对于像 len() 这样重载严重的函数非常重要。

types.Float 是所有浮点类型的类(types.float64types.Float 的一个实例)。通常,根据参数类型所在的类而不是特定的实例进行匹配,会更具前瞻性(然而,在返回类型时——主要是在类型推断阶段——通常必须返回一个类型实例)。

cgutils.create_struct_proxy()interval._getvalue() 由于 Numba 传递值的方式,会产生一些样板代码。值作为 llvmlite.ir.Value 的实例传递,这可能过于受限:特别是 LLVM 结构体值是相当底层的。结构体代理是 LLVM 结构体值的一个临时包装器,可以轻松地获取或设置结构体的成员。_getvalue() 调用只是从包装器中获取 LLVM 值。

装箱和拆箱

如果您此时尝试使用 Interval 实例,您肯定会收到错误 “无法将 Interval 转换为原生值”。这是因为 Numba 尚不知道如何从 Python Interval 实例创建原生区间值。让我们教它如何做到这一点

from numba.extending import unbox, NativeValue
from contextlib import ExitStack

@unbox(IntervalType)
def unbox_interval(typ, obj, c):
    """
    Convert a Interval object to a native interval structure.
    """
    is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit)
    interval = cgutils.create_struct_proxy(typ)(c.context, c.builder)

    with ExitStack() as stack:
        lo_obj = c.pyapi.object_getattr_string(obj, "lo")
        with cgutils.early_exit_if_null(c.builder, stack, lo_obj):
            c.builder.store(cgutils.true_bit, is_error_ptr)
        lo_native = c.unbox(types.float64, lo_obj)
        c.pyapi.decref(lo_obj)
        with cgutils.early_exit_if(c.builder, stack, lo_native.is_error):
            c.builder.store(cgutils.true_bit, is_error_ptr)

        hi_obj = c.pyapi.object_getattr_string(obj, "hi")
        with cgutils.early_exit_if_null(c.builder, stack, hi_obj):
            c.builder.store(cgutils.true_bit, is_error_ptr)
        hi_native = c.unbox(types.float64, hi_obj)
        c.pyapi.decref(hi_obj)
        with cgutils.early_exit_if(c.builder, stack, hi_native.is_error):
            c.builder.store(cgutils.true_bit, is_error_ptr)

        interval.lo = lo_native.value
        interval.hi = hi_native.value

    return NativeValue(interval._getvalue(), is_error=c.builder.load(is_error_ptr))

拆箱 是“将 Python 对象转换为原生值”的另一个名称(它符合将 Python 对象视为一个包含简单原生值的复杂盒子的想法)。该函数返回一个 NativeValue 对象,该对象允许其调用者访问计算出的原生值、错误位以及可能存在的其他信息。

上面的代码片段大量使用了 c.pyapi 对象,该对象提供了对 Python 解释器 C API 子集的访问。请注意使用 early_exit_if_null 来检测和处理在拆箱对象时可能发生的任何错误(例如,尝试传递 Interval('a', 'b'))。

我们还想执行反向操作,称为装箱,以便从 Numba 函数返回区间值

from numba.extending import box

@box(IntervalType)
def box_interval(typ, val, c):
    """
    Convert a native interval structure to an Interval object.
    """
    ret_ptr = cgutils.alloca_once(c.builder, c.pyapi.pyobj)
    fail_obj = c.pyapi.get_null_object()

    with ExitStack() as stack:
        interval = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
        lo_obj = c.box(types.float64, interval.lo)
        with cgutils.early_exit_if_null(c.builder, stack, lo_obj):
            c.builder.store(fail_obj, ret_ptr)

        hi_obj = c.box(types.float64, interval.hi)
        with cgutils.early_exit_if_null(c.builder, stack, hi_obj):
            c.pyapi.decref(lo_obj)
            c.builder.store(fail_obj, ret_ptr)

        class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Interval))
        with cgutils.early_exit_if_null(c.builder, stack, class_obj):
            c.pyapi.decref(lo_obj)
            c.pyapi.decref(hi_obj)
            c.builder.store(fail_obj, ret_ptr)

        # NOTE: The result of this call is not checked as the clean up
        # has to occur regardless of whether it is successful. If it
        # fails `res` is set to NULL and a Python exception is set.
        res = c.pyapi.call_function_objargs(class_obj, (lo_obj, hi_obj))
        c.pyapi.decref(lo_obj)
        c.pyapi.decref(hi_obj)
        c.pyapi.decref(class_obj)
        c.builder.store(res, ret_ptr)

    return c.builder.load(ret_ptr)

使用它

nopython 模式 函数现在能够使用 Interval 对象以及您在它们上定义的各种操作。您可以尝试以下函数

from numba import njit

@njit
def inside_interval(interval, x):
    return interval.lo <= x < interval.hi

@njit
def interval_width(interval):
    return interval.width

@njit
def sum_intervals(i, j):
    return Interval(i.lo + j.lo, i.hi + j.hi)

结论

我们已经展示了如何完成以下任务

  • 通过继承 Type 类来定义新的 Numba 类型类

  • 为非参数类型定义一个单例 Numba 类型实例

  • 使用 typeof_impl.register 教会 Numba 如何推断特定类的 Python 值的 Numba 类型

  • 使用 as_numba_type.register 教会 Numba 如何推断 Python 类型本身的 Numba 类型

  • 使用 StructModelregister_model 定义 Numba 类型的数据模型

  • 使用 @box 装饰器实现 Numba 类型的装箱函数

  • 使用 @unbox 装饰器和 NativeValue 类实现 Numba 类型的拆箱函数

  • 使用 @type_callable@lower_builtin 装饰器定义和实现可调用对象

  • 使用 make_attribute_wrapper 便利函数公开只读结构体属性

  • 使用 @overload_attribute 装饰器实现只读属性