jax.lax.gather#

jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False, mode=None, fill_value=None)[source]#

Gather 运算符。

包装了 XLA 的 Gather 运算符

gather() 是一个底层运算符,具有复杂的语义,大多数 JAX 用户永远不需要直接调用它。相反,您应该更喜欢使用 Numpy 风格的索引,和/或 jax.numpy.ndarray.at(),或许可以与 jax.vmap() 结合使用。

参数:
  • operand (ArrayLike) – 应该从中提取切片的数组

  • start_indices (ArrayLike) – 应该在其中提取切片的索引

  • dimension_numbers (GatherDimensionNumbers) – 一个 lax.GatherDimensionNumbers 对象,描述了 operandstart_indices 和输出的维度如何关联。

  • slice_sizes (Shape) – 每个切片的大小。必须是非负整数序列,长度等于 ndim(operand)

  • indices_are_sorted (bool) – indices 是否已知已排序。如果为 true,可能会提高某些后端上的性能。

  • unique_indices (bool) – 从 operand 收集的元素是否保证彼此不重叠。如果为 True,这可能会提高某些后端上的性能。JAX 不会检查此承诺:如果元素重叠,则行为未定义。

  • mode (str | GatherScatterMode | None | None) – 如何处理越界索引:当设置为 'clip' 时,索引将被钳制,以便切片在界限内;当设置为 'fill''drop' 时,gather 将为受影响的切片返回充满 fill_value 的切片。当设置为 'promise_in_bounds' 时,越界索引的行为是实现定义的。

  • fill_value – 当 mode'fill' 时,为越界切片返回的填充值。否则将被忽略。对于非精确类型,默认为 NaN;对于有符号类型,默认为最大负值;对于无符号类型,默认为最大正值;对于布尔值,默认为 True

返回:

包含 gather 输出的数组。

返回类型:

Array

示例

如上所述,您基本上永远不应该直接使用 gather(),而应该使用 Numpy 风格的索引表达式来从数组中收集值。

例如,以下是如何使用直接索引语义提取特定索引处的值,这将降低到 XLA 的 Gather 运算符

>>> import jax.numpy as jnp
>>> x = jnp.array([10, 11, 12])
>>> indices = jnp.array([0, 1, 1, 2, 2, 2])
>>> x[indices]
Array([10, 11, 11, 12, 12, 12], dtype=int32)

为了控制 indices_are_sortedunique_indicesmodefill_value 等设置,您可以使用 jax.numpy.ndarray.at 语法

>>> x.at[indices].get(indices_are_sorted=True, mode="promise_in_bounds")
Array([10, 11, 11, 12, 12, 12], dtype=int32)

相比之下,以下是直接使用 gather() 的等效函数调用,这不是典型用户应该做的

>>> from jax import lax
>>> lax.gather(x, indices[:, None], slice_sizes=(1,),
...            dimension_numbers=lax.GatherDimensionNumbers(
...                offset_dims=(),
...                collapsed_slice_dims=(0,),
...                start_index_map=(0,)),
...            indices_are_sorted=True,
...            mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
Array([10, 11, 11, 12, 12, 12], dtype=int32)