jax.experimental.multihost_utils.global_array_to_host_local_array#
- jax.experimental.multihost_utils.global_array_to_host_local_array(global_inputs, global_mesh, pspecs)[源代码]#
将全局 jax.Array 转换为主机本地 jax.Array。
您可以使用此函数过渡到 jax.Array。 将 jax.Array 与 pjit 一起使用具有与将 GDA 与 pjit 一起使用相同的语义,即 pjit 的所有 jax.Array 输入都应为全局形状,并且来自 pjit 的输出也将是全局形状的 jax.Array
您可以使用此函数将来自 pjit 的全局形状 jax.Array 输出再次转换为主机本地值,以便向 jax.Array 的过渡可以是一个机械变化。
使用示例
>>> from jax.experimental import multihost_utils >>> >>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) >>> >>> with mesh: ... global_out = pjitted_fun(global_inputs) >>> >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs)
- 参数:
global_inputs (Any) – 全局 jax.Array 的 Pytree。
global_mesh (jax.sharding.Mesh) –
jax.sharding.Mesh
对象。 网格必须是连续的,这意味着主机的所有本地设备必须形成一个子立方体。pspecs (Any) –
jax.sharding.PartitionSpec
对象的 Pytree。
- 返回:
主机本地数组的 Pytree。