jax.ref.addupdate#
- jax.ref.addupdate(ref, idx, x)[源代码]#
就地向 Ref 中的元素添加值。
这类似于 NumPy 数组
ref和 NumPy 风格的索引器idx的ref[idx] += value。然而,对于 Refref,执行ref[idx] += value实际上会执行一次ref_get、加法和ref_set,因此在使用自动微分时,使用此函数可能更有效。有关可变数组 Ref 的更多信息,请参阅 Ref 指南。- 参数:
ref (AbstractRef) – 一个
jax.ref.Ref对象。返回时,缓冲区将被此操作修改。idx (Indexer | tuple[Indexer, ...] | None) – 一个 NumPy 风格的索引器
x (Array) – 一个
jax.Array对象(注意,不是jax.ref.Ref),其中包含要在指定索引处添加的值。
- 返回:
无
- 返回类型:
无
示例
>>> import jax >>> ref = jax.new_ref(jax.numpy.arange(5)) >>> jax.ref.addupdate(ref, 2, 10) >>> ref Ref([ 0, 1, 12, 3, 4], dtype=int32)
通过索引语法实现等效操作
>>> ref = jax.new_ref(jax.numpy.arange(5)) >>> ref[2] += 10 >>> ref Ref([ 0, 1, 12, 3, 4], dtype=int32)
使用
...向标量 Ref 添加值>>> ref = jax.new_ref(jax.numpy.int32(2)) >>> ref[...] += 10 >>> ref Ref(12, dtype=int32)