jax.numpy.put_along_axis#

jax.numpy.put_along_axis(arr, indices, values, axis, inplace=True, *, mode=None)[源代码]#

通过匹配1维索引和数据切片将值放入目标数组中。

numpy.put_along_axis() 的 JAX 实现。

numpy.put_along_axis() 的语义是就地修改数组,这对于 JAX 的不可变数组来说是不可能的。 JAX 版本返回输入的修改后的副本,并添加 inplace 参数,用户必须将其设置为 False 以提醒此 API 差异。

参数:
  • arr (ArrayLike) – 将要放入值的数组。

  • indices (ArrayLike) – 要在其中放入值的索引数组。

  • values (ArrayLike) – 要放入数组的值的数组。

  • axis (int | None) – 沿其放入值的轴。 如果未指定,则在应用索引之前将展平数组。

  • inplace (bool) – 必须设置为 False 以表明输入不是就地修改的,而是返回修改后的副本。

  • mode (str | None) – 超出范围的索引模式。 有关 mode 选项的更多讨论,请参阅 jax.numpy.ndarray.at

返回值:

更新了指定条目的 a 的副本。

返回类型:

Array

另请参阅

示例

>>> from jax import numpy as jnp
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
>>> i = jnp.argmax(a, axis=1, keepdims=True)
>>> print(i)
[[1]
 [0]]
>>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
>>> print(b)
[[10 99 20]
 [99 40 50]]