jax.experimental.pallas.run_scoped#

jax.experimental.pallas.run_scoped(f, *types, collective_axes=(), **kw_types)[源]#

使用分配的引用调用函数并返回结果。

位置参数和关键字参数描述了为每个参数分配哪些引用类型。每个后端除了 jax.experimental.pallas.MemoryRef 之外,还有自己的一组引用类型。

当指定 collective_axes 时,对于所有仅在集体轴上程序 ID 不同的程序,将返回相同的分配。如果沿该轴的所有程序中没有调用相同的 run_scoped,则会出错。

参数:
  • f (可调用对象[..., 任意类型])

  • types (任意类型)

  • collective_axes (可哈希 | 元组[可哈希, ...])

  • kw_types (任意类型)

返回类型:

任意类型