jax.lax.all_gather#
- jax.lax.all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False)[source]#
跨所有副本收集 x 的值。
如果
x
是一个 pytree,那么结果等同于将此函数映射到树中的每个叶子节点。这等同于 all_to_all(broadcast(x)),但速度更快。
- 参数:
x – 具有名为
axis_name
的映射轴的数组。axis_name – 可哈希的 Python 对象,用于命名 pmapped 轴(有关更多详细信息,请参阅
jax.pmap()
文档)。axis_index_groups – 可选的列表列表,包含轴索引(例如,对于大小为 4 的轴,[[0, 1], [2, 3]] 将在前两个和后两个副本上运行 all gather)。组必须精确覆盖所有轴索引一次,并且所有组的大小必须相同。
axis – 一个位置轴,沿
axis_name
的块将连接到其中。tiled – 当为
False
时,块将被堆叠到输出中索引为axis
的新位置轴中。当为True
时,axis
必须引用现有的位置维度,并且块将被连接到该维度中。
- 返回值:
表示沿轴
axis_name
进行 all-gather 的结果的数组。形状与x.shape
相同,但当
tiled
为False
时,在位置axis
中会有一个等于轴axis_name
大小的新维度,当
tiled
为True
时,位置axis
中的维度大小乘以轴axis_name
的大小。
例如,在有 4 个 XLA 设备可用的情况下
>>> x = np.arange(4) >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x) >>> print(y) [[0 1 2 3] [0 1 2 3] [0 1 2 3] [0 1 2 3]]
使用 axis_index_groups 的示例,按偶数和奇数设备 ID 分组
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] >>> def f(x): ... return jax.lax.all_gather( ... x, 'i', axis_index_groups=[[0, 2], [3, 1]]) >>> y = jax.pmap(f, axis_name='i')(x) >>> print(y) [[[ 0 1 2 3] [ 8 9 10 11]] [[12 13 14 15] [ 4 5 6 7]] [[ 0 1 2 3] [ 8 9 10 11]] [[12 13 14 15] [ 4 5 6 7]]]