jax.numpy.block#
- jax.numpy.block(arrays)[源码]#
从块列表创建数组。
JAX 实现
numpy.block()。- 参数:
arrays (ArrayLike | list[ArrayLike]) – 一个数组,或嵌套的数组列表,它们将被连接起来形成最终数组。
- 返回:
一个由输入构建的单个数组。
- 返回类型:
示例
考虑这些块
>>> zeros = jnp.zeros((2, 2)) >>> ones = jnp.ones((2, 2)) >>> twos = jnp.full((2, 2), 2) >>> threes = jnp.full((2, 2), 3)
将单个数组传递给
block()会返回该数组>>> jnp.block(zeros) Array([[0., 0.], [0., 0.]], dtype=float32)
传递一个简单的数组列表会沿最后一个轴将它们连接起来
>>> jnp.block([zeros, ones]) Array([[0., 0., 1., 1.], [0., 0., 1., 1.]], dtype=float32)
传递一个双重嵌套的数组列表会沿最后一个轴连接内层列表,并沿倒数第二个轴连接外层列表
>>> jnp.block([[zeros, ones], ... [twos, threes]]) Array([[0., 0., 1., 1.], [0., 0., 1., 1.], [2., 2., 3., 3.], [2., 2., 3., 3.]], dtype=float32)
请注意,块不必在所有维度上对齐,尽管沿连接轴的大小必须匹配。例如,这是有效的,因为在内部水平连接后,所得的块对于外部垂直连接具有有效的形状。
>>> a = jnp.zeros((2, 1)) >>> b = jnp.ones((2, 3)) >>> c = jnp.full((1, 2), 2) >>> d = jnp.full((1, 2), 3) >>> jnp.block([[a, b], [c, d]]) Array([[0., 1., 1., 1.], [0., 1., 1., 1.], [2., 2., 3., 3.]], dtype=float32)
另请注意,此逻辑可推广到 3 维或更多维的块。这是一个 3 维块状数组
>>> x = jnp.arange(6).reshape((1, 2, 3)) >>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)] >>> jnp.block(blocks).shape (5, 8, 9)