jax.numpy.put_along_axis#

jax.numpy.put_along_axis(arr, indices, values, axis, inplace=True, *, mode=None)[源代码]#

通过匹配 1d 索引和数据切片将值放入目标数组。

numpy.put_along_axis() 的 JAX 实现。

numpy.put_along_axis() 的语义是原位修改数组,这对于 JAX 的不可变数组是不可能的。JAX 版本返回输入的修改副本,并添加 inplace 参数,用户必须将其设置为 False` 以提醒此 API 差异。

参数:
返回:

更新了指定条目的 a 的副本。

返回类型:

Array

另请参阅

示例

>>> from jax import numpy as jnp
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
>>> i = jnp.argmax(a, axis=1, keepdims=True)
>>> print(i)
[[1]
 [0]]
>>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
>>> print(b)
[[10 99 20]
 [99 40 50]]