使用 @stencil
装饰器
模板(Stencil)是一种常见的计算模式,其中数组元素根据称为模板核(stencil kernel)的固定模式进行更新。Numba 提供了 @stencil
装饰器,以便用户可以轻松指定一个模板核,然后 Numba 生成必要的循环代码以将该核应用于某个输入数组。因此,模板装饰器允许更清晰、更简洁的代码,并且与并行 jit 选项结合使用,通过模板执行的并行化实现更高的性能。
基本用法
@stencil
装饰器的使用示例
from numba import stencil
@stencil
def kernel1(a):
return 0.25 * (a[0, 1] + a[1, 0] + a[0, -1] + a[-1, 0])
模板核通过看起来像标准 Python 函数定义的方式指定,但在数组索引方面有不同的语义。模板生成一个与输入数组大小和形状相同的输出数组,尽管根据核定义其类型可能不同。从概念上讲,模板核对输出数组中的每个元素运行一次。模板核的返回值是写入输出数组中该特定元素的值。
参数 a
表示应用核的输入数组。对此数组的索引是相对于正在处理的输出数组的当前元素进行的。例如,如果正在处理元素 (x, y)
,则模板核中的 a[0, 0]
对应于输入数组中的 a[x + 0, y + 0]
。类似地,模板核中的 a[-1, 1]
对应于输入数组中的 a[x - 1, y + 1]
。
根据指定的核,该核可能不适用于输出数组的边界,因为这可能导致输入数组越界访问。模板装饰器处理这种情况的方式取决于选择哪个func_or_mode。默认模式是模板装饰器将输出数组的边界元素设置为零。
要在输入数组上调用模板,请像调用常规函数一样调用模板,并将输入数组作为参数传递。例如,使用上面定义的核
>>> import numpy as np
>>> input_arr = np.arange(100).reshape((10, 10))
array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
[50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
[60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
[70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
[80, 81, 82, 83, 84, 85, 86, 87, 88, 89],
[90, 91, 92, 93, 94, 95, 96, 97, 98, 99]])
>>> output_arr = kernel1(input_arr)
array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 11., 12., 13., 14., 15., 16., 17., 18., 0.],
[ 0., 21., 22., 23., 24., 25., 26., 27., 28., 0.],
[ 0., 31., 32., 33., 34., 35., 36., 37., 38., 0.],
[ 0., 41., 42., 43., 44., 45., 46., 47., 48., 0.],
[ 0., 51., 52., 53., 54., 55., 56., 57., 58., 0.],
[ 0., 61., 62., 63., 64., 65., 66., 67., 68., 0.],
[ 0., 71., 72., 73., 74., 75., 76., 77., 78., 0.],
[ 0., 81., 82., 83., 84., 85., 86., 87., 88., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
>>> input_arr.dtype
dtype('int64')
>>> output_arr.dtype
dtype('float64')
请注意,模板装饰器已确定指定模板核的输出类型为 float64
,因此已将输出数组创建为 float64
,而输入数组的类型为 int64
。
模板参数
模板核定义可以接受任意数量的参数,但有以下规定。第一个参数必须是数组。输出数组的大小和形状将与第一个参数相同。其他参数可以是标量或数组。对于数组参数,这些数组在每个维度上必须至少与第一个参数(数组)一样大。所有此类输入数组参数的数组索引都是相对的。
核形状推断和边界处理
在上面的示例和大多数情况下,模板核中的数组索引将仅使用 Integer
字面量。在这种情况下,模板装饰器能够分析模板核以确定其大小。在上面的示例中,模板装饰器确定核的形状为 3 x 3
,因为索引 -1
到 1
用于第一和第二维度。请注意,模板装饰器还正确处理非对称和非方形模板核。
根据模板核的大小,模板装饰器能够计算输出数组中边界的大小。如果将核应用于输入数组的某个元素会导致索引越界,则该元素属于输出数组的边界。在上面的示例中,在每个维度中访问点 -1
和 +1
,因此输出数组在所有维度上都有一个大小为一的边界。
并行模式如果可能,能够从简单的表达式中推断出核索引作为常量。例如
@njit(parallel=True)
def stencil_test(A):
c = 2
B = stencil(
lambda a, c: 0.3 * (a[-c+1] + a[0] + a[c-1]))(A, c)
return B
模板装饰器选项
注意
模板装饰器未来可能会得到增强,以提供额外的边界处理机制。目前,只实现了一种行为,即 "constant"
(详见下面的 func_or_mode
)。
neighborhood
有时,仅使用 Integer
字面量来编写模板核可能会不方便。例如,假设我们想计算一个时间序列数据的滞后 30 天移动平均值。我们可以写 (a[-29] + a[-28] + ... + a[-1] + a[0]) / 30
,但模板装饰器使用 neighborhood
选项提供了更简洁的形式
@stencil(neighborhood = ((-29, 0),))
def kernel2(a):
cumul = 0
for i in range(-29, 1):
cumul += a[i]
return cumul / 30
邻域选项是一个元组的元组。外部元组的长度等于输入数组的维度数。内部元组的长度始终为二,因为内部元组的每个元素对应于相应维度中使用的最小和最大索引偏移量。
如果用户指定了邻域但核访问了指定邻域之外的元素,**行为是未定义的。**
func_or_mode
可选的 func_or_mode
参数控制输出数组边界的处理方式。目前,只有一个受支持的值,即 "constant"
。在 constant
模式下,如果核将访问输入数组有效范围之外的元素,则不应用模板核。在这种情况下,输出数组中的这些元素将被赋予一个常量值,该值由 cval
参数指定。
cval
可选的 cval
参数默认为零,但可以设置为任何所需的值,如果 func_or_mode
参数设置为 constant
,则该值将用于输出数组的边界。在所有其他模式下,cval
参数被忽略。cval
参数的类型必须与模板核的返回类型匹配。如果用户希望输出数组由特定类型构造,则应确保模板核返回该类型。
standard_indexing
默认情况下,模板核中的所有数组访问都作为相对索引进行处理,如上所述。然而,有时将辅助数组(例如,权重数组)传递给模板核并让该数组使用标准 Python 索引而不是相对索引可能会更有利。为此,存在模板装饰器选项 standard_indexing
,其值是一个字符串集合,这些字符串的名称与要使用标准 Python 索引而非相对索引访问的模板函数参数匹配
@stencil(standard_indexing=("b",))
def kernel3(a, b):
return a[-1] * b[0] + a[0] + b[1]
StencilFunc
模板装饰器返回一个类型为 StencilFunc
的可调用对象。一个 StencilFunc
对象包含多个属性,但用户可能感兴趣的唯一属性是 neighborhood
属性。如果将 neighborhood
选项传递给模板装饰器,则提供的邻域将存储在此属性中。否则,在首次执行或编译时,系统将如上所述计算邻域,然后将计算出的邻域存储到此属性中。用户可以检查该属性,以验证计算出的邻域是否正确。
模板调用选项
在内部,模板装饰器将指定的模板核转换为常规的 Python 函数。此函数将具有与模板核定义中指定的参数相同的参数,但还将包含以下可选参数。
out
可选的 out
参数被添加到 Numba 生成的每个模板函数中。如果指定,out
参数告诉 Numba 用户正在提供自己预分配的数组,用于模板的输出。在这种情况下,模板函数将不会分配自己的输出数组。用户应确保模板核的返回类型可以安全地转换为用户指定输出数组的元素类型,遵循NumPy ufunc 转换规则。
使用示例如下所示
>>> import numpy as np
>>> input_arr = np.arange(100).reshape((10, 10))
>>> output_arr = np.full(input_arr.shape, 0.0)
>>> kernel1(input_arr, out=output_arr)