jax.numpy.block#

jax.numpy.block(arrays)[源码]#

从块列表创建数组。

JAX 实现 numpy.block()

参数:

arrays (ArrayLike | list[ArrayLike]) – 一个数组,或嵌套的数组列表,它们将被连接起来形成最终数组。

返回:

一个由输入构建的单个数组。

返回类型:

Array

另请参阅

示例

考虑这些块

>>> zeros = jnp.zeros((2, 2))
>>> ones = jnp.ones((2, 2))
>>> twos = jnp.full((2, 2), 2)
>>> threes = jnp.full((2, 2), 3)

将单个数组传递给 block() 会返回该数组

>>> jnp.block(zeros)
Array([[0., 0.],
       [0., 0.]], dtype=float32)

传递一个简单的数组列表会沿最后一个轴将它们连接起来

>>> jnp.block([zeros, ones])
Array([[0., 0., 1., 1.],
       [0., 0., 1., 1.]], dtype=float32)

传递一个双重嵌套的数组列表会沿最后一个轴连接内层列表,并沿倒数第二个轴连接外层列表

>>> jnp.block([[zeros, ones],
...            [twos, threes]])
Array([[0., 0., 1., 1.],
       [0., 0., 1., 1.],
       [2., 2., 3., 3.],
       [2., 2., 3., 3.]], dtype=float32)

请注意,块不必在所有维度上对齐,尽管沿连接轴的大小必须匹配。例如,这是有效的,因为在内部水平连接后,所得的块对于外部垂直连接具有有效的形状。

>>> a = jnp.zeros((2, 1))
>>> b = jnp.ones((2, 3))
>>> c = jnp.full((1, 2), 2)
>>> d = jnp.full((1, 2), 3)
>>> jnp.block([[a, b], [c, d]])
Array([[0., 1., 1., 1.],
       [0., 1., 1., 1.],
       [2., 2., 3., 3.]], dtype=float32)

另请注意,此逻辑可推广到 3 维或更多维的块。这是一个 3 维块状数组

>>> x = jnp.arange(6).reshape((1, 2, 3))
>>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)]
>>> jnp.block(blocks).shape
(5, 8, 9)