jax.sharding 模块#

#

class jax.sharding.Sharding(*args, **kwargs)#

描述 jax.Array 如何在设备上分布布局。

property addressable_devices: set[Device]#

当前进程可寻址的 Sharding 中的设备集合。

addressable_devices_indices_map(global_shape)[source]#

从可寻址设备到其各自包含的数组数据切片的映射。

addressable_devices_indices_map 包含 device_indices_map 中适用于可寻址设备的部分。

参数:

global_shape (Shape)

返回类型:

Mapping[Device, Index | None]

property device_set: set[Device][source]#

Sharding 所跨越的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。

devices_indices_map(global_shape)[source]#

返回从设备到其各自包含的数组切片的映射。

该映射包含所有全局设备,即包括来自其他进程的不可寻址设备。

参数:

global_shape (Shape)

返回类型:

Mapping[Device, Index]

is_equivalent_to(other, ndim)[source]#

如果两个分片方式等价,则返回 True

如果两个分片方式将相同的逻辑数组分片放置在相同的设备上,则它们是等价的。

参数:
返回类型:

bool

property is_fully_addressable: bool[source]#

此分片方式是否完全可寻址?

如果当前进程可以寻址 Sharding 中命名的所有设备,则该分片方式是完全可寻址的。is_fully_addressable 在多进程 JAX 中等同于“is_local”。

property is_fully_replicated: bool[source]#

此分片方式是否完全复制?

如果每个设备都拥有整个数据的完整副本,则该分片方式是完全复制的。

property memory_kind: str | None[source]#

返回该分片方式的内存类型。

property num_devices: int[source]#

该分片方式包含的设备数量。

shard_shape(global_shape)[source]#

返回每个设备上的数据形状。

此函数返回的分片形状由 global_shape 和该分片方式的属性计算得出。

参数:

global_shape (Shape)

返回类型:

Shape

with_memory_kind(kind)[source]#

返回具有指定内存类型的新 Sharding 实例。

参数:

kind (str)

返回类型:

Sharding(分片)

class jax.sharding.SingleDeviceSharding(*args, **kwargs)#

Bases: Sharding

一种将数据放置在单个设备上的 Sharding

参数:

device – 单个 Device

示例

>>> single_device_sharding = jax.sharding.SingleDeviceSharding(
...     jax.devices()[0])
property device_set: set[Device][source]#

Sharding 所跨越的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。

devices_indices_map(global_shape)[source]#

返回从设备到其各自包含的数组切片的映射。

该映射包含所有全局设备,即包括来自其他进程的不可寻址设备。

参数:

global_shape (Shape)

返回类型:

Mapping[Device, Index]

property is_fully_addressable: bool[source]#

此分片方式是否完全可寻址?

如果当前进程可以寻址 Sharding 中命名的所有设备,则该分片方式是完全可寻址的。is_fully_addressable 在多进程 JAX 中等同于“is_local”。

property is_fully_replicated: bool[source]#

此分片方式是否完全复制?

如果每个设备都拥有整个数据的完整副本,则该分片方式是完全复制的。

property memory_kind: str | None[source]#

返回该分片方式的内存类型。

property num_devices: int[source]#

该分片方式包含的设备数量。

with_memory_kind(kind)[source]#

返回具有指定内存类型的新 Sharding 实例。

参数:

kind (str)

返回类型:

SingleDeviceSharding

class jax.sharding.NamedSharding(*args, **kwargs)#

Bases: Sharding

NamedSharding 使用命名轴来表示分片。

NamedSharding 由一个设备 Mesh 和一个描述如何在该网格上切分数组的 PartitionSpec 组成。

Mesh 是一个 JAX 设备的多维 NumPy 数组,其中网格的每个轴都有一个名称,例如 'x''y'

PartitionSpec 是一个元组,其元素可以是 None、网格轴名称或网格轴名称的元组。每个元素描述了如何将输入维度切分到零个或多个网格维度上。例如,PartitionSpec('x', 'y') 表示数据的第一维沿着网格的 'x' 轴进行切分,第二维沿着网格的 'y' 轴进行切分。

分布式数组和自动并行化 以及 显式分片 教程包含更多详细信息和图表,解释了 MeshPartitionSpec 的用法。

参数:

示例

>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> spec = P('x', 'y')
>>> named_sharding = jax.sharding.NamedSharding(mesh, spec)
property addressable_devices: set[Device]#

当前进程可寻址的 Sharding 中的设备集合。

property device_set: set[Device]#

Sharding 所跨越的设备集合。

在多控制器 JAX 中,设备集合是全局的,即包括来自其他进程的不可寻址设备。

is_equivalent_to(other, ndim)#

如果两个分片方式等价,则返回 True

如果两个分片方式将相同的逻辑数组分片放置在相同的设备上,则它们是等价的。

参数:

ndim (int)

返回类型:

bool

property is_fully_addressable: bool#

此分片方式是否完全可寻址?

如果当前进程可以寻址 Sharding 中命名的所有设备,则该分片方式是完全可寻址的。is_fully_addressable 在多进程 JAX 中等同于“is_local”。

property is_fully_replicated: bool#

此分片方式是否完全复制?

如果每个设备都拥有整个数据的完整副本,则该分片方式是完全复制的。

property memory_kind: str | None#

返回该分片方式的内存类型。

property mesh#

(self) -> object

property num_devices: int#

该分片方式包含的设备数量。

property spec#

(self) -> jax::PartitionSpec

with_memory_kind(kind)#

返回具有指定内存类型的新 Sharding 实例。

参数:

kind (str)

返回类型:

NamedSharding

class jax.sharding.PartitionSpec(*args, **kwargs)#

描述如何跨设备网格切分数组的元组。

每个元素要么是 None,要么是一个字符串,或者是一个字符串元组。更多详细信息请参阅 jax.sharding.NamedSharding 的文档。

引入该类的目的是为了让 JAX 的 pytree 工具能够区分分片规范与应该被视为 pytree 的元组。

property reduced#

(self) -> frozenset

property unreduced#

(self) -> frozenset

class jax.sharding.Mesh(devices, axis_names, axis_types=None)#

声明在此管理器范围内可用的硬件资源。

请参阅 分布式数组和自动并行化 以及 显式分片 教程。

参数:
  • devices (np.ndarray) – 包含 JAX 设备对象的 NumPy ndarray 对象(例如通过 jax.devices() 获取)。

  • axis_names (tuple[MeshAxisName, ...]) – 要分配给 devices 参数维度的资源轴名称序列。其长度应与 devices 的秩相匹配。

  • axis_types (tuple[AxisType, ...]) – 可选的 jax.sharding.AxisType 条目元组,对应于 axis_names。更多信息请参阅 显式分片

示例

>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P, NamedSharding
>>> import numpy as np
...
>>> # Declare a 2D mesh with axes `x` and `y`.
>>> devices = np.array(jax.devices()).reshape(4, 2)
>>> mesh = Mesh(devices, ('x', 'y'))
>>> inp = np.arange(16).reshape(8, 2)
>>> arr = jax.device_put(inp, NamedSharding(mesh, P('x', 'y')))
>>> out = jax.jit(lambda x: x * 2)(arr)
>>> assert out.sharding == NamedSharding(mesh, P('x', 'y'))