jax.numpy.put#
- jax.numpy.put(a, ind, v, mode=None, *, inplace=True)[源代码]#
将元素按指定索引放入数组。
JAX 对
numpy.put()的实现。numpy.put()的语义是就地修改数组,这对于 JAX 的不可变数组是不可能的。JAX 版本返回输入数组的修改副本,并添加了inplace参数,用户必须将其设置为 False``,以提醒注意此 API 差异。- 参数:
a (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – 要放入值的数组。
ind (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – 在展平数组中放置值的索引数组。
v (Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray) – 要放入数组的值数组。
mode (str | None) –
指定如何处理越界索引的字符串。支持的值
"clip"(默认): 将越界索引剪辑到最后一个索引。"wrap": 将越界索引环绕到数组的开头。
inplace (bool) – 必须设置为 False,以指示输入不会被原地修改,而是返回一个修改后的副本。
- 返回:
一个
a的副本,其中包含指定的更新条目。- 返回类型:
另请参阅
jax.numpy.place(): 通过布尔掩码将元素放置到数组中。jax.numpy.ndarray.at(): 使用 NumPy 风格的索引进行数组更新。jax.numpy.take(): 按指定索引从数组中提取值。
示例
>>> x = jnp.zeros(5, dtype=int) >>> indices = jnp.array([0, 2, 4]) >>> values = jnp.array([10, 20, 30]) >>> jnp.put(x, indices, values, inplace=False) Array([10, 0, 20, 0, 30], dtype=int32)
这等效于以下
jax.numpy.ndarray.at索引语法>>> x.at[indices].set(values) Array([10, 0, 20, 0, 30], dtype=int32)
有两种处理越界索引的模式。默认情况下,它们会被剪辑
>>> indices = jnp.array([0, 2, 6]) >>> jnp.put(x, indices, values, inplace=False, mode='clip') Array([10, 0, 20, 0, 30], dtype=int32)
或者,它们可以被环绕到数组的开头
>>> jnp.put(x, indices, values, inplace=False, mode='wrap') Array([10, 30, 20, 0, 0], dtype=int32)
对于 N 维输入,索引指的是展平后的数组
>>> x = jnp.zeros((3, 5), dtype=int) >>> indices = jnp.array([0, 7, 14]) >>> jnp.put(x, indices, values, inplace=False) Array([[10, 0, 0, 0, 0], [ 0, 0, 20, 0, 0], [ 0, 0, 0, 0, 30]], dtype=int32)