jax.numpy.block#
- jax.numpy.block(arrays)[source]#
从块列表中创建数组。
JAX 对
numpy.block()
的实现。示例
考虑这些块
>>> zeros = jnp.zeros((2, 2)) >>> ones = jnp.ones((2, 2)) >>> twos = jnp.full((2, 2), 2) >>> threes = jnp.full((2, 2), 3)
将单个数组传递给
block()
会返回该数组>>> jnp.block(zeros) Array([[0., 0.], [0., 0.]], dtype=float32)
传递一个简单的数组列表会沿最后一个轴连接它们
>>> jnp.block([zeros, ones]) Array([[0., 0., 1., 1.], [0., 0., 1., 1.]], dtype=float32)
传递一个双层嵌套的数组列表会将内部列表沿最后一个轴连接,并将外部列表沿倒数第二个轴连接
>>> jnp.block([[zeros, ones], ... [twos, threes]]) Array([[0., 0., 1., 1.], [0., 0., 1., 1.], [2., 2., 3., 3.], [2., 2., 3., 3.]], dtype=float32)
请注意,块不必在所有维度上都对齐,但沿连接轴的大小必须匹配。例如,这是有效的,因为在内部水平连接之后,结果块对于外部垂直连接具有有效的形状。
>>> a = jnp.zeros((2, 1)) >>> b = jnp.ones((2, 3)) >>> c = jnp.full((1, 2), 2) >>> d = jnp.full((1, 2), 3) >>> jnp.block([[a, b], [c, d]]) Array([[0., 1., 1., 1.], [0., 1., 1., 1.], [2., 2., 3., 3.]], dtype=float32)
另请注意,此逻辑推广到 3 维或更多维的块。这是一个 3 维块式数组
>>> x = jnp.arange(6).reshape((1, 2, 3)) >>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)] >>> jnp.block(blocks).shape (5, 8, 9)