jax.numpy.put#

jax.numpy.put(a, ind, v, mode=None, *, inplace=True)[源代码]#

将元素按指定索引放入数组。

JAX 对 numpy.put() 的实现。

numpy.put() 的语义是就地修改数组,这对于 JAX 的不可变数组是不可能的。JAX 版本返回输入数组的修改副本,并添加了 inplace 参数,用户必须将其设置为 False``,以提醒注意此 API 差异。

参数:
  • a (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – 要放入值的数组。

  • ind (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – 在展平数组中放置值的索引数组。

  • v (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – 要放入数组的值数组。

  • mode (str | None) –

    指定如何处理越界索引的字符串。支持的值

    • "clip" (默认): 将越界索引剪辑到最后一个索引。

    • "wrap": 将越界索引环绕到数组的开头。

  • inplace (bool) – 必须设置为 False,以指示输入不会被原地修改,而是返回一个修改后的副本。

返回:

一个 a 的副本,其中包含指定的更新条目。

返回类型:

Array

另请参阅

示例

>>> x = jnp.zeros(5, dtype=int)
>>> indices = jnp.array([0, 2, 4])
>>> values = jnp.array([10, 20, 30])
>>> jnp.put(x, indices, values, inplace=False)
Array([10,  0, 20,  0, 30], dtype=int32)

这等效于以下 jax.numpy.ndarray.at 索引语法

>>> x.at[indices].set(values)
Array([10,  0, 20,  0, 30], dtype=int32)

有两种处理越界索引的模式。默认情况下,它们会被剪辑

>>> indices = jnp.array([0, 2, 6])
>>> jnp.put(x, indices, values, inplace=False, mode='clip')
Array([10,  0, 20,  0, 30], dtype=int32)

或者,它们可以被环绕到数组的开头

>>> jnp.put(x, indices, values, inplace=False, mode='wrap')
Array([10,  30, 20,  0, 0], dtype=int32)

对于 N 维输入,索引指的是展平后的数组

>>> x = jnp.zeros((3, 5), dtype=int)
>>> indices = jnp.array([0, 7, 14])
>>> jnp.put(x, indices, values, inplace=False)
Array([[10,  0,  0,  0,  0],
       [ 0,  0, 20,  0,  0],
       [ 0,  0,  0,  0, 30]], dtype=int32)