跳到主要内容
jax.experimental.pallas.swap
jax.experimental.pallas.swap
-
jax.experimental.pallas.swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None, _function_name='swap')[源代码]
交换给定索引处的值并返回旧值。
有关参数的含义,请参见load()
。
- 返回值:
交换之前存储在 ref 中的值。
- 返回类型:
jax.Array