jax.lax.dynamic_update_slice#

jax.lax.dynamic_update_slice(operand, update, start_indices, *, allow_negative_indices=True)[source]#

包装 XLA 的 DynamicUpdateSlice 运算符。

参数:
  • operand (Array | np.ndarray) – 要切片的数组。

  • update (ArrayLike) – 包含要写入 operand 的新值的数组。

  • start_indices (Array | Sequence[ArrayLike]) – 标量索引列表,每维度一个。

  • allow_negative_indices (bool | Sequence[bool]) – 布尔值或布尔值序列,每维度一个;如果传递布尔值,则应用于所有维度。对于每个维度,如果为 true,则允许负索引,并且相对于数组末尾解释。如果为 false,则负索引被视为超出边界,结果是实现定义的,通常钳制为第一个索引。

返回:

包含切片的数组。

返回类型:

Array

示例

这是一个更新一维切片更新的示例

>>> x = jnp.zeros(6)
>>> y = jnp.ones(3)
>>> dynamic_update_slice(x, y, (2,))
Array([0., 0., 1., 1., 1., 0.], dtype=float32)

如果更新切片太大而无法放入数组,则将调整起始索引以使其适合

>>> dynamic_update_slice(x, y, (3,))
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
>>> dynamic_update_slice(x, y, (5,))
Array([0., 0., 0., 1., 1., 1.], dtype=float32)

这是一个二维切片更新的示例

>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones((2, 2))
>>> dynamic_update_slice(x, y, (1, 2))
Array([[0., 0., 0., 0.],
       [0., 0., 1., 1.],
       [0., 0., 1., 1.],
       [0., 0., 0., 0.]], dtype=float32)

另请参阅