jax.lax.dynamic_update_slice#
- jax.lax.dynamic_update_slice(operand, update, start_indices, *, allow_negative_indices=True)[源代码]#
封装 XLA 的 DynamicUpdateSlice 运算符。
- 参数:
operand (Array | np.ndarray) – 要切片的数组。
update (ArrayLike) – 包含要写入 operand 的新值的数组。
start_indices (Array | Sequence[ArrayLike]) – 标量索引列表,每个维度一个。
allow_negative_indices (bool | Sequence[bool]) – 一个布尔值或布尔值序列,每个维度一个;如果传递一个布尔值,它将应用于所有维度。对于每个维度,如果为 true,则允许使用负索引,并且将其解释为相对于数组末尾。如果为 false,则将负索引视为超出范围,并且结果由实现定义,通常钳制到第一个索引。
- 返回:
包含切片的数组。
- 返回类型:
示例
这是一个更新一维切片更新的示例
>>> x = jnp.zeros(6) >>> y = jnp.ones(3) >>> dynamic_update_slice(x, y, (2,)) Array([0., 0., 1., 1., 1., 0.], dtype=float32)
如果更新切片太大而无法放入数组,则将调整起始索引以使其适合
>>> dynamic_update_slice(x, y, (3,)) Array([0., 0., 0., 1., 1., 1.], dtype=float32) >>> dynamic_update_slice(x, y, (5,)) Array([0., 0., 0., 1., 1., 1.], dtype=float32)
这是一个二维切片更新的示例
>>> x = jnp.zeros((4, 4)) >>> y = jnp.ones((2, 2)) >>> dynamic_update_slice(x, y, (1, 2)) Array([[0., 0., 0., 0.], [0., 0., 1., 1.], [0., 0., 1., 1.], [0., 0., 0., 0.]], dtype=float32)
另请参阅
lax.dynamic_update_index_in_dim
lax.dynamic_update_slice_in_dim