jax.make_array_from_single_device_arrays#
- jax.make_array_from_single_device_arrays(shape, sharding, arrays, *, dtype=None)[source]#
- 从单个设备数组序列中返回一个
jax.Array
。 输入
sharding
的网格中的每个设备都必须在arrays
中包含一个数组。
- 参数:
shape (Shape) — 输出
jax.Array
的形状。此参数传递的信息已包含在sharding
和arrays
中,此处作为双重检查。sharding (Sharding) — 分片:一个全局的 Sharding 实例,描述输出 jax.Array 如何在设备上分布。
arrays (Sequence[basearray.Array]) — list 或 tuple 类型的
jax.Array
,其中每个都可由单个设备寻址。len(arrays)
必须等于len(sharding.addressable_devices)
,并且每个数组的形状必须相同。对于多进程代码,每个进程将使用与其数据对应的不同arrays
参数进行调用。这些数组通常通过jax.device_put
创建。dtype (DTypeLike | None) — 输出
jax.Array
的 dtype。如果未提供,则使用arrays
中第一个数组的 dtype。如果arrays
为空,则必须提供dtype
参数。
- 返回:
- 一个全局的
jax.Array
,按照sharding
分片,形状等于shape
,并且每个设备上的内容 与
arrays
匹配。
- 一个全局的
- 返回类型:
ArrayImpl
示例
>>> import math >>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> import numpy as np ... >>> mesh_rows = 2 >>> mesh_cols = jax.device_count() // 2 ... >>> global_shape = (8, 8) >>> mesh = Mesh(np.array(jax.devices()).reshape(mesh_rows, mesh_cols), ('x', 'y')) >>> sharding = jax.sharding.NamedSharding(mesh, P('x', 'y')) >>> inp_data = np.arange(math.prod(global_shape)).reshape(global_shape) ... >>> arrays = [ ... jax.device_put(inp_data[index], d) ... for d, index in sharding.addressable_devices_indices_map(global_shape).items()] ... >>> arr = jax.make_array_from_single_device_arrays(global_shape, sharding, arrays) >>> assert arr.shape == (8,8) # arr.shape is (8,8) regardless of jax.device_count()
如果您有一个本地数组并希望将其转换为全局 jax.Array,请使用
jax.make_array_from_process_local_data
。- 从单个设备数组序列中返回一个