jax.sharding 模块#
类#
- class jax.sharding.Sharding(*args, **kwargs)#
描述一个
jax.Array如何跨设备布局。- addressable_devices_indices_map(global_shape)[source]#
从可寻址设备到每个设备包含的数组切片(slice)的映射。
addressable_devices_indices_map包含适用于可寻址设备的device_indices_map的部分。- 参数:
global_shape (Shape)
- 返回类型:
Mapping[Device, Index | None]
- property device_set: set[Device][source]#
此
Sharding所跨越的设备集合。在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。
- devices_indices_map(global_shape)[source]#
返回一个从设备到每个设备包含的数组切片(slice)的映射。
该映射包括所有全局设备,即包括来自其他进程的不可寻址设备。
- 参数:
global_shape (Shape)
- 返回类型:
Mapping[Device, Index]
- property is_fully_addressable: bool[source]#
此分片是否完全可寻址?
如果当前进程可以寻址
Sharding中命名的所有设备,则该分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable等同于“is_local”。
- shard_shape(global_shape)[source]#
返回每个设备上数据的形状。
此函数返回的分片形状是根据
global_shape和分片的属性计算得出的。- 参数:
global_shape (Shape)
- 返回类型:
Shape
- class jax.sharding.SingleDeviceSharding(*args, **kwargs)#
Bases:
Sharding一个将数据放置在单个设备上的
Sharding。- 参数:
device – 一个单独的
Device。
示例
>>> single_device_sharding = jax.sharding.SingleDeviceSharding( ... jax.devices()[0])
- property device_set: set[Device][source]#
此
Sharding所跨越的设备集合。在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。
- devices_indices_map(global_shape)[source]#
返回一个从设备到每个设备包含的数组切片(slice)的映射。
该映射包括所有全局设备,即包括来自其他进程的不可寻址设备。
- 参数:
global_shape (Shape)
- 返回类型:
Mapping[Device, Index]
- property is_fully_addressable: bool[source]#
此分片是否完全可寻址?
如果当前进程可以寻址
Sharding中命名的所有设备,则该分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable等同于“is_local”。
- class jax.sharding.NamedSharding(*args, **kwargs)#
Bases:
ShardingNamedSharding使用命名轴(named axes)表达分片。一个
NamedSharding是一个设备Mesh和PartitionSpec的组合,它描述了如何在该 mesh 上分片一个数组。一个
Mesh是一个 JAX 设备的(多维)NumPy 数组,其中 mesh 的每个轴都有一个名称,例如'x'或'y'。一个
PartitionSpec是一个元组,其元素可以是None、一个字符串或一个字符串元组。每个元素描述一个输入维度如何跨零个或多个 mesh 维度进行分区。例如,PartitionSpec('x', 'y')表示数据的第一个维度在 mesh 的x轴上分片,第二个维度在 mesh 的y轴上分片。有关
Mesh和PartitionSpec用法的更多详细信息和图示,请参阅 分布式数组和自动并行化 和 显式分片 教程。- 参数:
mesh – 一个
jax.sharding.Mesh对象。spec – 一个
jax.sharding.PartitionSpec对象。
示例
>>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) >>> spec = P('x', 'y') >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
- property device_set: set[Device][source]#
此
Sharding所跨越的设备集合。在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。
- property is_fully_addressable: bool[source]#
此分片是否完全可寻址?
如果当前进程可以寻址
Sharding中命名的所有设备,则该分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable等同于“is_local”。
- property mesh#
(self) -> object
- property spec#
(self) -> jax::PartitionSpec
- class jax.sharding.PmapSharding(*args, **kwargs)#
Bases:
Sharding描述
jax.pmap()使用的分片。- classmethod default(shape, sharded_dim=0, devices=None)[source]#
创建一个
PmapSharding,它匹配jax.pmap()的默认放置。- 参数:
shape (Shape) – 输入数组的形状。
sharded_dim (int | None) – 输入数组分片的维度。默认为 0。
devices (Sequence[xc.Device] | None) – 可选的设备序列。如果省略,则使用 pmap 的隐式设备顺序,即
jax.local_devices()的顺序。
- 返回类型:
- property devices#
(self) -> numpy.ndarray
- devices_indices_map(global_shape)[source]#
返回一个从设备到每个设备包含的数组切片(slice)的映射。
该映射包括所有全局设备,即包括来自其他进程的不可寻址设备。
- 参数:
global_shape (Shape)
- 返回类型:
Mapping[Device, Index]
- is_equivalent_to(other, ndim)[source]#
当两个分片等价时返回
True。两个分片是等价的,如果它们将相同的逻辑数组分片放置在相同的设备上。
- 参数:
self (PmapSharding)
other (PmapSharding)
ndim (int)
- 返回类型:
- property is_fully_addressable: bool#
此分片是否完全可寻址?
如果当前进程可以寻址
Sharding中命名的所有设备,则该分片是完全可寻址的。在多进程 JAX 中,is_fully_addressable等同于“is_local”。
- shard_shape(global_shape)[source]#
返回每个设备上数据的形状。
此函数返回的分片形状是根据
global_shape和分片的属性计算得出的。- 参数:
global_shape (Shape)
- 返回类型:
Shape
- property sharding_spec#
(self) -> jax::ShardingSpec
- class jax.sharding.PartitionSpec(*args, **kwargs)#
描述如何跨设备 mesh 对数组进行分区的元组。
每个元素可以是
None、字符串或字符串元组。有关更多详细信息,请参阅jax.sharding.NamedSharding的文档。此类的存在是为了让 JAX 的 pytree 工具能够区分分区规范和应视为 pytree 的元组。
- property reduced#
(self) -> frozenset
- property unreduced#
(self) -> frozenset
- class jax.sharding.Mesh(devices, axis_names, axis_types=None)[source]#
声明此管理器作用域内可用的硬件资源。
请参阅 分布式数组和自动并行化 和 显式分片 教程。
- 参数:
devices (np.ndarray) – 一个包含 JAX 设备对象的 NumPy ndarray 对象(例如从
jax.devices()获取)。axis_names (tuple[MeshAxisName, ...]) – 分配给
devices参数维度的资源轴名称序列。其长度应与devices的秩匹配。axis_types (tuple[AxisType, ...]) – 一个可选的
jax.sharding.AxisType条目元组,对应于axis_names。有关更多信息,请参阅 显式分片。
示例
>>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P, NamedSharding >>> import numpy as np ... >>> # Declare a 2D mesh with axes `x` and `y`. >>> devices = np.array(jax.devices()).reshape(4, 2) >>> mesh = Mesh(devices, ('x', 'y')) >>> inp = np.arange(16).reshape(8, 2) >>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y'))) >>> out = jax.jit(lambda x: x * 2)(arr) >>> assert out.sharding == NamedSharding(mesh, P('x', 'y'))