jax.ref.swap#

jax.ref.swap(ref, idx, value, _function_name='ref_swap')[源代码]#

在原地更新数组值,同时返回旧值。

这等同于 ref[idx], prev = value, ref[idx],然后返回 prev,其中 idx 是 NumPy 风格的索引器。有关可变数组引用的更多信息,请参阅 Ref 指南

参数:
  • ref (AbstractRef | TransformedRef) – 一个 jax.ref.Ref 对象。返回时,缓冲区将被此操作修改。

  • idx (Indexer | tuple[Indexer, ...] | None) – 一个 NumPy 风格的索引器

  • value (Array) – 一个 jax.Array 对象(注意,不是 jax.ref.Ref),包含要设置在数组中的值。

  • _function_name (str)

返回:

一个 jax.Array,包含 idx 处的前一个值。

返回类型:

Array

示例

>>> import jax
>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> jax.ref.swap(ref, 3, 10)
Array(3, dtype=int32)
>>> ref
Ref([ 0,  1,  2, 10,  4], dtype=int32)

通过索引语法实现等效操作

>>> ref = jax.new_ref(jax.numpy.arange(5))
>>> ref[3], prev = 10, ref[3]
>>> prev
Array(3, dtype=int32)
>>> ref
Ref([ 0,  1,  2, 10,  4], dtype=int32)

使用 ... 来交换标量引用的值

>>> ref = jax.new_ref(jax.numpy.int32(5))
>>> jax.ref.swap(ref, ..., 10)
Array(5, dtype=int32)
>>> ref
Ref(10, dtype=int32)