jax.experimental.pallas.run_scoped#
- jax.experimental.pallas.run_scoped(f, *types, collective_axes=(), **kw_types)[源]#
使用分配的引用调用函数并返回结果。
位置参数和关键字参数描述了为每个参数分配哪些引用类型。每个后端除了
jax.experimental.pallas.MemoryRef
之外,还有自己的一组引用类型。当指定 collective_axes 时,对于所有仅在集体轴上程序 ID 不同的程序,将返回相同的分配。如果沿该轴的所有程序中没有调用相同的 run_scoped,则会出错。
- 参数:
f (可调用对象[..., 任意类型])
types (任意类型)
collective_axes (可哈希 | 元组[可哈希, ...])
kw_types (任意类型)
- 返回类型:
任意类型