jax.numpy.atleast_2d#

jax.numpy.atleast_2d(*arys)[来源]#

将输入转换为至少具有 2 个维度的数组。

numpy.atleast_2d() 的 JAX 实现。

参数:
  • 参数。 (零个多个类数组对象)

  • arys (类数组对象)

返回:

与输入值对应的数组或数组列表。形状为 () 的数组被转换为形状为 (1, 1) 的数组,形状为 (N,) 的一维 (1D) 数组被转换为形状为 (1, N) 的数组,所有其他形状的数组则保持不变。

返回类型:

Array | list[Array]

示例

标量参数被转换为二维(2D)、大小为 1 的数组

>>> x = jnp.float32(1.0)
>>> jnp.atleast_2d(x)
Array([[1.]], dtype=float32)

一维参数的形状前会添加一个单位维度

>>> y = jnp.arange(4)
>>> jnp.atleast_2d(y)
Array([[0, 1, 2, 3]], dtype=int32)

高维输入保持不变

>>> z = jnp.ones((2, 3))
>>> jnp.atleast_2d(z)
Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)

可以一次向函数传递多个参数,在这种情况下,将返回结果列表

>>> jnp.atleast_2d(x, y)
[Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)]