jax.make_array_from_single_device_arrays#

jax.make_array_from_single_device_arrays(shape, sharding, arrays, *, dtype=None)[source]#
从单个设备数组序列中返回一个 jax.Array

输入 sharding 的网格中的每个设备都必须在 arrays 中包含一个数组。

参数:
  • shape (Shape) — 输出 jax.Array 的形状。此参数传递的信息已包含在 shardingarrays 中,此处作为双重检查。

  • sharding (Sharding) — 分片:一个全局的 Sharding 实例,描述输出 jax.Array 如何在设备上分布。

  • arrays (Sequence[basearray.Array]) — listtuple 类型的 jax.Array,其中每个都可由单个设备寻址。len(arrays) 必须等于 len(sharding.addressable_devices),并且每个数组的形状必须相同。对于多进程代码,每个进程将使用与其数据对应的不同 arrays 参数进行调用。这些数组通常通过 jax.device_put 创建。

  • dtype (DTypeLike | None) — 输出 jax.Array 的 dtype。如果未提供,则使用 arrays 中第一个数组的 dtype。如果 arrays 为空,则必须提供 dtype 参数。

返回:

一个全局的 jax.Array,按照 sharding 分片,形状等于 shape,并且每个设备上的内容

arrays 匹配。

返回类型:

ArrayImpl

示例

>>> import math
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> mesh_rows = 2
>>> mesh_cols =  jax.device_count() // 2
...
>>> global_shape = (8, 8)
>>> mesh = Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
>>> inp_data = np.arange(math.prod(global_shape)).reshape(global_shape)
...
>>> arrays = [
...    jax.device_put(inp_data[index], d)
...        for d, index in sharding.addressable_devices_indices_map(global_shape).items()]
...
>>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays)
>>> assert arr.shape == (8,8) # arr.shape is (8,8) regardless of jax.device_count()

如果您有一个本地数组并希望将其转换为全局 jax.Array,请使用 jax.make_array_from_process_local_data