jax.numpy.atleast_2d#
- jax.numpy.atleast_2d(*arys)[来源]#
将输入转换为至少具有 2 个维度的数组。
numpy.atleast_2d() 的 JAX 实现。
- 参数:
参数。 (零个或多个类数组对象)
arys (类数组对象)
- 返回:
与输入值对应的数组或数组列表。形状为
()
的数组被转换为形状为(1, 1)
的数组,形状为(N,)
的一维 (1D) 数组被转换为形状为(1, N)
的数组,所有其他形状的数组则保持不变。- 返回类型:
示例
标量参数被转换为二维(2D)、大小为 1 的数组
>>> x = jnp.float32(1.0) >>> jnp.atleast_2d(x) Array([[1.]], dtype=float32)
一维参数的形状前会添加一个单位维度
>>> y = jnp.arange(4) >>> jnp.atleast_2d(y) Array([[0, 1, 2, 3]], dtype=int32)
高维输入保持不变
>>> z = jnp.ones((2, 3)) >>> jnp.atleast_2d(z) Array([[1., 1., 1.], [1., 1., 1.]], dtype=float32)
可以一次向函数传递多个参数,在这种情况下,将返回结果列表
>>> jnp.atleast_2d(x, y) [Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)]