jax.lax.split#

jax.lax.split(operand, sizes, axis=0)[源代码]#

沿着 axis 切割数组。

参数:
  • operand (ArrayLike) – 要切割的数组

  • sizes (Sequence[int]) – 切割后数组的大小。sizes 的总和必须等于 operandaxis 维度的尺寸。

  • axis (int) – 沿着哪个轴切割数组。

返回:

一个包含 len(sizes) 个数组的序列。如果 sizes[s1, s2, ...],则此函数沿着 axis 返回大小分别为 s1, s2, ... 的块。

返回类型:

Sequence[Array]