jax.distributed 模块#

initialize([coordinator_address, ...])

初始化 JAX 分布式系统。

关闭()

关闭分布式系统。