jax.numpy.atleast_1d#

jax.numpy.atleast_1d(*arys)[源]#

将输入转换为至少一维的数组。

JAX 对 numpy.atleast_1d() 的实现。

参数:
  • 参数。 (零个多个类数组对象)

  • arys (类数组对象)

返回:

与输入值对应的数组或数组列表。形状为 () 的数组会转换为形状为 (1,) 的数组,其他形状的数组则保持不变。

返回类型:

数组 | 列表[数组]

示例

标量参数将转换为一维、长度为1的数组

>>> x = jnp.float32(1.0)
>>> jnp.atleast_1d(x)
Array([1.], dtype=float32)

更高维度的输入保持不变

>>> y = jnp.arange(4)
>>> jnp.atleast_1d(y)
Array([0, 1, 2, 3], dtype=int32)

可以一次向函数传递多个参数,在这种情况下,将返回结果列表

>>> jnp.atleast_1d(x, y)
[Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)]