jax.experimental.pallas.swap#

jax.experimental.pallas.swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None, _function_name='swap')[源代码]#

交换给定索引处的值,并返回旧值。

有关参数的含义,请参见load()

返回:

交换前 ref 中存储的值。

返回类型:

jax.Array