jax.numpy.put_along_axis#
- jax.numpy.put_along_axis(arr, indices, values, axis, inplace=True, *, mode=None)[源代码]#
通过匹配1维索引和数据切片将值放入目标数组中。
numpy.put_along_axis()
的 JAX 实现。numpy.put_along_axis()
的语义是就地修改数组,这对于 JAX 的不可变数组来说是不可能的。 JAX 版本返回输入的修改后的副本,并添加inplace
参数,用户必须将其设置为 False 以提醒此 API 差异。- 参数:
arr (ArrayLike) – 将要放入值的数组。
indices (ArrayLike) – 要在其中放入值的索引数组。
values (ArrayLike) – 要放入数组的值的数组。
axis (int | None) – 沿其放入值的轴。 如果未指定,则在应用索引之前将展平数组。
inplace (bool) – 必须设置为 False 以表明输入不是就地修改的,而是返回修改后的副本。
mode (str | None) – 超出范围的索引模式。 有关
mode
选项的更多讨论,请参阅jax.numpy.ndarray.at
。
- 返回值:
更新了指定条目的
a
的副本。- 返回类型:
另请参阅
jax.numpy.put()
:将元素放入给定索引的数组中。jax.numpy.place()
:通过布尔掩码将元素放入数组中。jax.numpy.ndarray.at()
:使用 NumPy 样式的索引进行数组更新。jax.numpy.take()
:从给定索引的数组中提取值。jax.numpy.take_along_axis()
:沿轴从数组中提取值。
示例
>>> from jax import numpy as jnp >>> a = jnp.array([[10, 30, 20], [60, 40, 50]]) >>> i = jnp.argmax(a, axis=1, keepdims=True) >>> print(i) [[1] [0]] >>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False) >>> print(b) [[10 99 20] [99 40 50]]