jax.ref.addupdate#
- jax.ref.addupdate(ref, idx, x)[源代码]#
就地向 Ref 中的元素添加值。
这类似于 NumPy 数组
ref
和 NumPy 风格的索引器idx
的ref[idx] += value
。然而,对于 Refref
,执行ref[idx] += value
实际上会执行一次ref_get
、加法和ref_set
,因此在使用自动微分时,使用此函数可能更有效。有关可变数组 Ref 的更多信息,请参阅 Ref 指南。- 参数:
ref (AbstractRef) – 一个
jax.ref.Ref
对象。返回时,缓冲区将被此操作修改。idx (Indexer | tuple[Indexer, ...] | None) – 一个 NumPy 风格的索引器
x (Array) – 一个
jax.Array
对象(注意,不是jax.ref.Ref
),其中包含要在指定索引处添加的值。
- 返回:
无
- 返回类型:
无
示例
>>> import jax >>> ref = jax.new_ref(jax.numpy.arange(5)) >>> jax.ref.addupdate(ref, 2, 10) >>> ref Ref([ 0, 1, 12, 3, 4], dtype=int32)
通过索引语法实现等效操作
>>> ref = jax.new_ref(jax.numpy.arange(5)) >>> ref[2] += 10 >>> ref Ref([ 0, 1, 12, 3, 4], dtype=int32)
使用
...
向标量 Ref 添加值>>> ref = jax.new_ref(jax.numpy.int32(2)) >>> ref[...] += 10 >>> ref Ref(12, dtype=int32)