jax.sharding 模块#
类#
- class jax.sharding.Sharding(*args, **kwargs)#
描述
jax.Array如何在设备上分布布局。- addressable_devices_indices_map(global_shape)[source]#
从可寻址设备到其各自包含的数组数据切片的映射。
addressable_devices_indices_map包含device_indices_map中适用于可寻址设备的部分。- 参数:
global_shape (Shape)
- 返回类型:
Mapping[Device, Index | None]
- property device_set: set[Device][source]#
此
Sharding所跨越的设备集合。在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。
- devices_indices_map(global_shape)[source]#
返回从设备到其各自包含的数组切片的映射。
该映射包含所有全局设备,即包括来自其他进程的不可寻址设备。
- 参数:
global_shape (Shape)
- 返回类型:
Mapping[Device, Index]
- property is_fully_addressable: bool[source]#
此分片方式是否完全可寻址?
如果当前进程可以寻址
Sharding中命名的所有设备,则该分片方式是完全可寻址的。is_fully_addressable在多进程 JAX 中等同于“is_local”。
- shard_shape(global_shape)[source]#
返回每个设备上的数据形状。
此函数返回的分片形状由
global_shape和该分片方式的属性计算得出。- 参数:
global_shape (Shape)
- 返回类型:
Shape
- class jax.sharding.SingleDeviceSharding(*args, **kwargs)#
Bases:
Sharding一种将数据放置在单个设备上的
Sharding。- 参数:
device – 单个
Device。
示例
>>> single_device_sharding = jax.sharding.SingleDeviceSharding( ... jax.devices()[0])
- property device_set: set[Device][source]#
此
Sharding所跨越的设备集合。在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。
- devices_indices_map(global_shape)[source]#
返回从设备到其各自包含的数组切片的映射。
该映射包含所有全局设备,即包括来自其他进程的不可寻址设备。
- 参数:
global_shape (Shape)
- 返回类型:
Mapping[Device, Index]
- property is_fully_addressable: bool[source]#
此分片方式是否完全可寻址?
如果当前进程可以寻址
Sharding中命名的所有设备,则该分片方式是完全可寻址的。is_fully_addressable在多进程 JAX 中等同于“is_local”。
- class jax.sharding.NamedSharding(*args, **kwargs)#
Bases:
ShardingNamedSharding使用命名轴来表示分片。NamedSharding由一个设备Mesh和一个描述如何在该网格上切分数组的PartitionSpec组成。Mesh是一个 JAX 设备的多维 NumPy 数组,其中网格的每个轴都有一个名称,例如'x'或'y'。PartitionSpec是一个元组,其元素可以是None、网格轴名称或网格轴名称的元组。每个元素描述了如何将输入维度切分到零个或多个网格维度上。例如,PartitionSpec('x', 'y')表示数据的第一维沿着网格的'x'轴进行切分,第二维沿着网格的'y'轴进行切分。分布式数组和自动并行化 以及 显式分片 教程包含更多详细信息和图表,解释了
Mesh和PartitionSpec的用法。- 参数:
mesh –
jax.sharding.Mesh对象。spec –
jax.sharding.PartitionSpec对象。memory_kind – 指示分片内存类型的字符串。
示例
>>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) >>> spec = P('x', 'y') >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
- is_equivalent_to(other, ndim)#
如果两个分片方式等价,则返回
True。如果两个分片方式将相同的逻辑数组分片放置在相同的设备上,则它们是等价的。
- property is_fully_addressable: bool#
此分片方式是否完全可寻址?
如果当前进程可以寻址
Sharding中命名的所有设备,则该分片方式是完全可寻址的。is_fully_addressable在多进程 JAX 中等同于“is_local”。
- property mesh#
(self) -> object
- property spec#
(self) -> jax::PartitionSpec
- class jax.sharding.PartitionSpec(*args, **kwargs)#
描述如何跨设备网格切分数组的元组。
每个元素要么是
None,要么是一个字符串,或者是一个字符串元组。更多详细信息请参阅jax.sharding.NamedSharding的文档。引入该类的目的是为了让 JAX 的 pytree 工具能够区分分片规范与应该被视为 pytree 的元组。
- property reduced#
(self) -> frozenset
- property unreduced#
(self) -> frozenset
- class jax.sharding.Mesh(devices, axis_names, axis_types=None)#
声明在此管理器范围内可用的硬件资源。
请参阅 分布式数组和自动并行化 以及 显式分片 教程。
- 参数:
devices (np.ndarray) – 包含 JAX 设备对象的 NumPy ndarray 对象(例如通过
jax.devices()获取)。axis_names (tuple[MeshAxisName, ...]) – 要分配给
devices参数维度的资源轴名称序列。其长度应与devices的秩相匹配。axis_types (tuple[AxisType, ...]) – 可选的
jax.sharding.AxisType条目元组,对应于axis_names。更多信息请参阅 显式分片。
示例
>>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P, NamedSharding >>> import numpy as np ... >>> # Declare a 2D mesh with axes `x` and `y`. >>> devices = np.array(jax.devices()).reshape(4, 2) >>> mesh = Mesh(devices, ('x', 'y')) >>> inp = np.arange(16).reshape(8, 2) >>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y'))) >>> out = jax.jit(lambda x: x * 2)(arr) >>> assert out.sharding == NamedSharding(mesh, P('x', 'y'))