jax.numpy.extract#

jax.numpy.extract(condition, arr, *, size=None, fill_value=0)[source]#

返回满足条件的数组元素。

JAX 对 numpy.extract() 的实现。

参数:
  • condition (ArrayLike) – 条件数组。将被转换为布尔类型并展平为一维。

  • arr (ArrayLike) – 要提取的数值数组。将被展平为一维。

  • size (int | None) – 输出的可选静态大小。必须指定此参数,以使 extract 与 JAX 变换(如 jit()vmap())兼容。

  • fill_value (ArrayLike) – 如果指定了 size,则用此值填充填充项(默认值:0)。

返回:

提取条目的一维数组。如果指定了 size,结果将具有形状 (size,) 并用 fill_value 进行右填充。如果未指定 size,则输出形状将取决于 condition 中为 True 的条目数量。

返回类型:

数组

注意事项

此函数不要求 conditionarr 之间具有严格的形状一致性。如果 condition.size > arr.size,则 condition 将被截断;如果 arr.size > condition.size,则 arr 将被截断。

另请参阅

jax.numpy.compress()extract 的多维版本。

示例

从一维数组中提取值

>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> mask = (x % 2 == 0)
>>> jnp.extract(mask, x)
Array([2, 4, 6], dtype=int32)

在最简单的情况下,这等同于布尔索引。

>>> x[mask]
Array([2, 4, 6], dtype=int32)

为了与 JAX 变换一起使用,您可以传递 size 参数来指定输出的静态形状,以及一个可选的 fill_value(默认为零)

>>> jnp.extract(mask, x, size=len(x), fill_value=0)
Array([2, 4, 6, 0, 0, 0], dtype=int32)

请注意,与布尔索引不同,extract 不要求数组和条件的尺寸之间有严格的一致性,并且会有效地将两者截断到最小尺寸。

>>> short_mask = jnp.array([False, True])
>>> jnp.extract(short_mask, x)
Array([2], dtype=int32)
>>> long_mask = jnp.array([True, False, True, False, False, False, False, False])
>>> jnp.extract(long_mask, x)
Array([1, 3], dtype=int32)