jax.lax.dynamic_slice#
- jax.lax.dynamic_slice(operand, start_indices, slice_sizes, *, allow_negative_indices=True)[源代码]#
包装 XLA 的 DynamicSlice 运算符。
- 参数:
operand (Array | np.ndarray) – 要切片的数组。
start_indices (Array | np.ndarray | Sequence[ArrayLike]) – 每个维度一个标量索引的列表。这些值可能是动态的。
slice_sizes (Shape) – 切片的大小。必须是长度等于 ndim(operand) 的非负整数序列。 在 JIT 编译函数中,仅支持静态值(JIT 中的所有 JAX 数组都必须具有静态已知的大小)。
allow_negative_indices (bool | Sequence[bool]) – 一个布尔值或布尔值序列,每个维度一个; 如果传递一个布尔值,它将应用于所有维度。 对于每个维度,如果为 true,则允许负索引,并且相对于数组的末尾解释。 如果为 false,则将负索引视为超出范围,并且结果是实现定义的,通常钳制到第一个索引。
- 返回:
包含切片的数组。
- 返回类型:
示例
这是一个简单的二维动态切片
>>> x = jnp.arange(12).reshape(3, 4) >>> x Array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3)) Array([[ 5, 6, 7], [ 9, 10, 11]], dtype=int32)
请注意请求的切片超出数组边界的情况下的潜在意外行为; 在这种情况下,起始索引被调整为返回请求大小的切片
>>> dynamic_slice(x, (1, 1), (2, 4)) Array([[ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)