jax.lax.dynamic_slice_in_dim#

jax.lax.dynamic_slice_in_dim(operand, start_index, slice_size, axis=0, *, allow_negative_indices=True)[source]#

围绕应用于一个维度的 lax.dynamic_slice() 的便捷包装器。

这大致等效于沿指定轴应用的以下 Python 索引语法:operand[..., start_index:start_index + slice_size]

参数:
  • operand (Array | np.ndarray) – 要切片的数组。

  • start_index (ArrayLike) – (可能是动态的)起始索引

  • slice_size (int) – 静态切片大小

  • axis (int) – 应用切片的轴(默认为 0)

  • allow_negative_indices (bool) – 布尔值,指定是否允许负索引。如果为 true,则负索引相对于数组末尾。如果为 false,则负索引超出范围,结果是实现定义的。

返回:

包含切片的数组。

返回类型:

Array

示例

这是一个一维示例

>>> x = jnp.arange(5)
>>> dynamic_slice_in_dim(x, 1, 3)
Array([1, 2, 3], dtype=int32)

jax.lax.dynamic_slice 类似,超出范围的切片将被裁剪到有效范围

>>> dynamic_slice_in_dim(x, 4, 3)
Array([2, 3, 4], dtype=int32)

这是一个二维示例

>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_slice_in_dim(x, 1, 2, axis=1)
Array([[ 1,  2],
       [ 5,  6],
       [ 9, 10]], dtype=int32)