jax.experimental.pallas.load#

jax.experimental.pallas.load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None, eviction_policy=None, volatile=False)[源代码]#

返回从给定索引加载的数组。

如果既未指定 mask 也未指定 other,则此函数具有与 JAX 中 x_ref_or_view[idx] 相同的语义。

参数:
  • x_ref_or_view – 要从中加载的引用。

  • idx – 要使用的索引器。

  • mask – 一个可选的布尔掩码,指定要加载的索引。如果 mask 为 False 且未给出 other,则无法对结果数组中的值进行任何假设。

  • other – 一个可选值,用于 mask 为 False 的索引。

  • cache_modifier – 待记录。

  • eviction_policy – 待记录。

  • volatile – 待记录。

返回类型:

jax.Array