jax.lax.reduce_window#

jax.lax.reduce_window(operand, init_value, computation, window_dimensions, window_strides=None, padding='VALID', base_dilation=None, window_dilation=None)[源代码]#

在填充窗口上进行归约。

封装 XLA 的 ReduceWindowWithGeneralPadding 运算符。

参数:
  • operand (Any) – 输入数组或数组树。

  • init_value (Any) – 值或值树。树结构必须与 operand 的结构匹配。

  • computation (Callable) – 用于归约的可调用函数。输入和输出必须是与 operand 具有相同结构的树。

  • window_dimensions (core.Shape) – 指定窗口大小的整数序列。

  • window_strides (Sequence[int] | None) – 可选的整数序列,用于指定步幅,其长度与 window_dimensions 相同。默认值 (None) 表示每个窗口维度中的单位步幅。

  • padding (str | Sequence[tuple[int, int]]) – 字符串或整数元组序列,用于指定要使用的填充类型(默认值:“VALID”)。如果为字符串,则必须是“VALID”、“SAME”或“SAME_LOWER”之一。参见 jax.lax.padtype_to_pads() 实用程序。

  • base_dilation (Sequence[int] | None) – 用于基膨胀值的可选整数序列,其长度与 window_dimensions 相同。默认值 (None) 表示每个窗口维度中的单位膨胀。

  • window_dilation (Sequence[int] | None) – 用于窗口膨胀值的可选整数序列,其长度与 window_dimensions 相同。默认值 (None) 表示每个窗口维度中的单位膨胀。

返回:

operand 具有相同结构的数组树。

返回类型:

任意类型

示例

这是一个在 1 维数组中对成对元素进行窗口化乘积的简单示例

>>> import jax
>>> x = jax.numpy.arange(10, dtype='float32')
>>> x
Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)
>>> initial = jax.numpy.float32(1)
>>> jax.lax.reduce_window(x, initial, jax.lax.mul, window_dimensions=(2,))
Array([ 0.,  2.,  6., 12., 20., 30., 42., 56., 72.], dtype=float32)