jax.numpy.take#

jax.numpy.take(a, indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)[source]#

从数组中选取元素。

numpy.take() 的 JAX 实现,基于 jax.lax.gather() 实现。 在越界索引的情况下,JAX 的行为与 NumPy 不同; 请参阅下面的 mode 参数。

参数:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – 从中取值的数组。

  • indices (Array | ndarray | bool | number | bool | int | float | complex) – 从数组中取值的整数索引的 N 维数组。

  • axis (int | None) – 从中取值的轴。 如果未指定,则将在应用索引之前将数组展平。

  • mode (str | None) – 越界索引模式,"fill""clip"。 默认的 mode="fill" 返回越界索引的无效值(例如 NaN); fill_value 参数可以控制该值。 有关 mode 选项的更多讨论,请参阅 jax.numpy.ndarray.at

  • fill_value (bool | number | bool | int | float | complex | None) – 当 mode 为 ‘fill’ 时,为越界切片返回的填充值。 否则将被忽略。 对于非精确类型,默认为 NaN; 对于有符号类型,默认为最大负值; 对于无符号类型,默认为最大正值; 对于布尔值,默认为 True。

  • unique_indices (bool) – 如果为 True,则实现将假定索引是唯一的,这可能会导致某些后端上更有效的执行。 如果设置为 True 且索引不唯一,则输出未定义。

  • indices_are_sorted (bool) – 如果为 True,则实现将假定索引按升序排序,这可能会导致某些后端上更有效的执行。 如果设置为 True 且索引未排序,则输出未定义。

  • out (None)

返回:

a 中提取的值的数组。

返回类型:

Array

另请参阅

示例

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 6.]])
>>> indices = jnp.array([2, 0])

不传递轴会导致索引到展平的数组中

>>> jnp.take(x, indices)
Array([3., 1.], dtype=float32)
>>> x.ravel()[indices]  # equivalent indexing syntax
Array([3., 1.], dtype=float32)

传递轴会导致将索引应用于沿轴的每个子数组

>>> jnp.take(x, indices, axis=1)
Array([[3., 1.],
       [6., 4.]], dtype=float32)
>>> x[:, indices]  # equivalent indexing syntax
Array([[3., 1.],
       [6., 4.]], dtype=float32)

越界索引填充无效值。 对于浮点输入,这是 NaN

>>> jnp.take(x, indices, axis=0)
Array([[nan, nan, nan],
       [ 1.,  2.,  3.]], dtype=float32)
>>> x.at[indices].get(mode='fill', fill_value=jnp.nan)  # equivalent indexing syntax
Array([[nan, nan, nan],
       [ 1.,  2.,  3.]], dtype=float32)

可以使用 mode 参数调整此默认的越界行为,例如,我们可以改为剪裁到最后一个有效值

>>> jnp.take(x, indices, axis=0, mode='clip')
Array([[4., 5., 6.],
       [1., 2., 3.]], dtype=float32)
>>> x.at[indices].get(mode='clip')  # equivalent indexing syntax
Array([[4., 5., 6.],
       [1., 2., 3.]], dtype=float32)