jax.lax.split#

jax.lax.split(operand, sizes, axis=0)[来源]#

沿 axis 分割数组。

参数:
  • operand (ArrayLike) – 要分割的数组

  • sizes (Sequence[int]) – 分割数组的大小。大小之和必须等于 operandaxis 维度的大小。

  • axis (int) – 沿其分割数组的轴。

返回:

一个 len(sizes) 数组序列。如果 sizes[s1, s2, ...],则此函数返回沿 axis 提取的大小为 s1s2 的块。

返回类型:

Sequence[Array]