jax.numpy.take#
- jax.numpy.take(a, indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)[source]#
从数组中提取元素。
numpy.take()
的 JAX 实现,根据jax.lax.gather()
实现。在越界索引的情况下,JAX 的行为与 NumPy 不同;请参阅下面的mode
参数。- 参数:
a (Array | ndarray | bool | number | bool | int | float | complex) – 从中提取值的数组。
indices (Array | ndarray | bool | number | bool | int | float | complex) – 从数组中提取值的 N 维整数索引数组。
axis (int | None) – 沿其提取值的轴。如果未指定,则在应用索引之前将展平数组。
mode (str | None) – 越界索引模式,可以是
"fill"
或"clip"
。默认的mode="fill"
为越界索引返回无效值(例如 NaN);fill_value
参数控制此值。有关mode
选项的更多讨论,请参阅jax.numpy.ndarray.at
。fill_value (bool | number | bool | int | float | complex | None) – 当模式为 ‘fill’ 时,为越界切片返回的填充值。否则忽略。默认为非精确类型的 NaN,有符号类型的最大负值,无符号类型的最大正值,以及布尔值的 True。
unique_indices (bool) – 如果为 True,则实现将假定索引是唯一的,这可以在某些后端上实现更高效的执行。如果设置为 True 且索引不是唯一的,则输出未定义。
indices_are_sorted (bool) – 如果为 True,则实现将假定索引按升序排序,这可以在某些后端上实现更高效的执行。如果设置为 True 且索引未排序,则输出未定义。
out (None)
- 返回:
从
a
中提取的值的数组。- 返回类型:
参见
jax.numpy.ndarray.at
:通过索引语法提取值。
示例
>>> x = jnp.array([[1., 2., 3.], ... [4., 5., 6.]]) >>> indices = jnp.array([2, 0])
不传递轴会导致索引到展平的数组中
>>> jnp.take(x, indices) Array([3., 1.], dtype=float32) >>> x.ravel()[indices] # equivalent indexing syntax Array([3., 1.], dtype=float32)
传递轴会导致将索引应用于沿轴的每个子数组
>>> jnp.take(x, indices, axis=1) Array([[3., 1.], [6., 4.]], dtype=float32) >>> x[:, indices] # equivalent indexing syntax Array([[3., 1.], [6., 4.]], dtype=float32)
越界索引填充无效值。对于浮点输入,这是 NaN
>>> jnp.take(x, indices, axis=0) Array([[nan, nan, nan], [ 1., 2., 3.]], dtype=float32) >>> x.at[indices].get(mode='fill', fill_value=jnp.nan) # equivalent indexing syntax Array([[nan, nan, nan], [ 1., 2., 3.]], dtype=float32)
可以使用
mode
参数调整此默认的越界行为,例如,我们可以改为裁剪到最后一个有效值>>> jnp.take(x, indices, axis=0, mode='clip') Array([[4., 5., 6.], [1., 2., 3.]], dtype=float32) >>> x.at[indices].get(mode='clip') # equivalent indexing syntax Array([[4., 5., 6.], [1., 2., 3.]], dtype=float32)