jax.lax.reduce_window#
- jax.lax.reduce_window(operand, init_value, computation, window_dimensions, window_strides=None, padding='VALID', base_dilation=None, window_dilation=None)[源代码]#
在填充窗口上进行归约。
封装 XLA 的 ReduceWindowWithGeneralPadding 运算符。
- 参数:
operand (Any) – 输入数组或数组树。
init_value (Any) – 值或值树。树结构必须与
operand
的结构匹配。computation (Callable) – 用于归约的可调用函数。输入和输出必须是与
operand
具有相同结构的树。window_dimensions (core.Shape) – 指定窗口大小的整数序列。
window_strides (Sequence[int] | None) – 可选的整数序列,用于指定步幅,其长度与
window_dimensions
相同。默认值 (None
) 表示每个窗口维度中的单位步幅。padding (str | Sequence[tuple[int, int]]) – 字符串或整数元组序列,用于指定要使用的填充类型(默认值:“VALID”)。如果为字符串,则必须是“VALID”、“SAME”或“SAME_LOWER”之一。参见
jax.lax.padtype_to_pads()
实用程序。base_dilation (Sequence[int] | None) – 用于基膨胀值的可选整数序列,其长度与
window_dimensions
相同。默认值 (None
) 表示每个窗口维度中的单位膨胀。window_dilation (Sequence[int] | None) – 用于窗口膨胀值的可选整数序列,其长度与
window_dimensions
相同。默认值 (None
) 表示每个窗口维度中的单位膨胀。
- 返回:
与
operand
具有相同结构的数组树。- 返回类型:
任意类型
示例
这是一个在 1 维数组中对成对元素进行窗口化乘积的简单示例
>>> import jax >>> x = jax.numpy.arange(10, dtype='float32') >>> x Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)
>>> initial = jax.numpy.float32(1) >>> jax.lax.reduce_window(x, initial, jax.lax.mul, window_dimensions=(2,)) Array([ 0., 2., 6., 12., 20., 30., 42., 56., 72.], dtype=float32)