jax.experimental.multihost_utils.host_local_array_to_global_array#
- jax.experimental.multihost_utils.host_local_array_to_global_array(local_inputs, global_mesh, pspecs)[源代码]#
将主机本地值转换为全局分片 jax.Array。
此函数接收主机本地数据(在不同主机上可能不同),并使用这些数据填充全局数组,其中每台设备在每台主机上都会根据 global_mesh/pspects 定义的分片获得适当的数据切片。
例如
>>> global_mesh = jax.sharding.Mesh(jax.devices(), 'x') >>> pspecs = jax.sharding.PartitionSpec('x') >>> host_id = jax.process_index() >>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs) # NB: assumes jax.local_device_count() divides 4.
生成的数组的形状将是 (4 * num_processes),并且分布式值为:(0, 1, 2, 3, 0, 2, 4, 6, 0, 3, 6, 9, …),其中每个切片 np.arange(4) * host_id 将跨相应主机的设备进行分区。
类似地
>>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(jax.process_count(), jax.local_device_count()), ['host', 'dev']) >>> pspecs = jax.sharding.PartitionSpec('host') >>> host_id = jax.process_index() >>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs)
将创建相同的分布式值 (0, 1, 2, 3, 0, 2, 4, 6, …),但每个切片 np.arange(4) * i 将在相应的相应主机设备上复制。
另一方面,如果 pspecs = PartitionSpec(),表示跨所有轴复制,那么这个片段
>>> pspecs = jax.sharding.PartitionSpec() >>> arr = host_local_array_to_global_array(np.arange(4), mesh, pspecs)
将具有形状 (4,),值 (0, 1, 2, 3) 将在所有主机和设备上复制。
如果 local_inputs 不相同,且 pspec 指示数据复制,则行为未定义。
您可以使用此函数过渡到 jax.Array。将 pjit 与 jax.Array 结合使用与将 GDA 与 pjit 结合使用的语义相同,即所有传递给 pjit 的 jax.Array 输入都应是全局形状的。
如果您当前将主机本地值传递给 pjit,您可以使用此函数将您的主机本地值转换为全局数组,然后将其传递给 pjit。
示例用法。
>>> from jax.experimental import multihost_utils >>> >>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) >>> >>> with mesh: >>> global_out = pjitted_fun(global_inputs) >>> >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs)
请注意,此函数需要全局 mesh 是连续 mesh,即属于每个主机的设备在此 mesh 中应形成一个子立方体。要将本地数据移动到非连续 mesh 的全局数组,请改用 jax.make_array_from_callback 或 jax.make_array_from_single_device_arrays。
- 参数:
local_inputs (Any) – 主机本地值的 Pytree。
global_mesh (jax.sharding.Mesh) – 一个 jax.sharding.Mesh 对象。该 mesh 必须是连续 mesh,
mesh. (即所有主机的设备必须在此 mesh 中形成一个子立方体。)
pspecs (Any) – jax.sharding.PartitionSpec 的 Pytree。
- 返回:
全局数组的 Pytree。