jax.lax.all_gather#
- jax.lax.all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False)[源代码]#
跨所有副本收集 x 的值。
如果
x
是一个 pytree,则结果等效于将此函数映射到树中的每个叶子。这等效于 all_to_all(broadcast(x)),但速度更快。
- 参数:
x – 具有名为
axis_name
的映射轴的数组。axis_name – 可哈希的 Python 对象,用于命名 pmapped 轴(更多详细信息请参阅
jax.pmap()
文档)。axis_index_groups – 可选的列表的列表,包含轴索引(例如,对于大小为 4 的轴,[[0, 1], [2, 3]] 将在前两个和后两个副本上运行 all-gather)。分组必须恰好覆盖所有轴索引一次,并且所有组的大小必须相同。
axis – 一个位置轴,沿
axis_name
的块将在此轴上连接。tiled – 当
False
时,块将堆叠到输出中索引axis
的一个新位置轴中。 当True
时,axis
必须引用一个现有的位置维度,并且块将连接到该维度中。
- 返回值:
表示沿轴
axis_name
进行 all-gather 的结果的数组。形状与x.shape
相同,但当
tiled
为False
时,在位置axis
中有一个新的维度,其大小等于轴axis_name
的大小,当
tiled
为True
时,位置axis
中的维度大小乘以轴axis_name
的大小。
例如,有 4 个 XLA 设备可用
>>> x = np.arange(4) >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x) >>> print(y) [[0 1 2 3] [0 1 2 3] [0 1 2 3] [0 1 2 3]]
使用 axis_index_groups 的示例,组按偶数和奇数设备 ID 分割
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] >>> def f(x): ... return jax.lax.all_gather( ... x, 'i', axis_index_groups=[[0, 2], [3, 1]]) >>> y = jax.pmap(f, axis_name='i')(x) >>> print(y) [[[ 0 1 2 3] [ 8 9 10 11]] [[12 13 14 15] [ 4 5 6 7]] [[ 0 1 2 3] [ 8 9 10 11]] [[12 13 14 15] [ 4 5 6 7]]]