jax.experimental.pallas.swap#

jax.experimental.pallas.swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None, _function_name='swap')[源代码]#

交换给定索引处的值并返回旧值。

有关参数的含义,请参见load()

返回值:

交换之前存储在 ref 中的值。

返回类型:

jax.Array