jax.ref.swap#
- jax.ref.swap(ref, idx, value, _function_name='ref_swap')[源代码]#
在原地更新数组值,同时返回旧值。
这等同于
ref[idx], prev = value, ref[idx]
,然后返回prev
,其中idx
是 NumPy 风格的索引器。有关可变数组引用的更多信息,请参阅 Ref 指南。- 参数:
ref (AbstractRef | TransformedRef) – 一个
jax.ref.Ref
对象。返回时,缓冲区将被此操作修改。idx (Indexer | tuple[Indexer, ...] | None) – 一个 NumPy 风格的索引器
value (Array) – 一个
jax.Array
对象(注意,不是jax.ref.Ref
),包含要设置在数组中的值。_function_name (str)
- 返回:
一个
jax.Array
,包含 idx 处的前一个值。- 返回类型:
示例
>>> import jax >>> ref = jax.new_ref(jax.numpy.arange(5)) >>> jax.ref.swap(ref, 3, 10) Array(3, dtype=int32) >>> ref Ref([ 0, 1, 2, 10, 4], dtype=int32)
通过索引语法实现等效操作
>>> ref = jax.new_ref(jax.numpy.arange(5)) >>> ref[3], prev = 10, ref[3] >>> prev Array(3, dtype=int32) >>> ref Ref([ 0, 1, 2, 10, 4], dtype=int32)
使用
...
来交换标量引用的值>>> ref = jax.new_ref(jax.numpy.int32(5)) >>> jax.ref.swap(ref, ..., 10) Array(5, dtype=int32) >>> ref Ref(10, dtype=int32)