jax.ref.addupdate#

jax.ref.addupdate(ref, idx, x)[源代码]#

就地向 Ref 中的元素添加值。

这类似于 NumPy 数组 ref 和 NumPy 风格的索引器 idxref[idx] += value。然而,对于 Ref ref,执行 ref[idx] += value 实际上会执行一次 ref_get、加法和 ref_set,因此在使用自动微分时,使用此函数可能更有效。有关可变数组 Ref 的更多信息,请参阅 Ref 指南

参数:
  • ref (AbstractRef) – 一个 jax.ref.Ref 对象。返回时,缓冲区将被此操作修改。

  • idx (Indexer | tuple[Indexer, ...] | None) – 一个 NumPy 风格的索引器

  • x (Array) – 一个 jax.Array 对象(注意,不是 jax.ref.Ref),其中包含要在指定索引处添加的值。

返回:

返回类型:

示例

>>> import jax
>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> jax.ref.addupdate(ref, 2, 10)
>>> ref
Ref([ 0,  1, 12,  3,  4], dtype=int32)

通过索引语法实现等效操作

>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> ref[2] += 10
>>> ref
Ref([ 0,  1, 12,  3,  4], dtype=int32)

使用 ... 向标量 Ref 添加值

>>> ref = jax.new_ref(jax.numpy.int32(2))
>>> ref[...] += 10
>>> ref
Ref(12, dtype=int32)