jax.numpy.atleast_3d#
- jax.numpy.atleast_3d(*arys)[源代码]#
将输入转换为至少具有 3 个维度的数组。
JAX 对
numpy.atleast_3d()的实现。- 参数:
参数。 (零个或多个类数组对象)
arys (类数组对象)
- 返回:
一个数组或数组列表,对应于输入值。形状为
()的数组转换为形状(1, 1, 1),形状为(N,)的一维数组转换为形状(1, N, 1),形状为(M, N)的二维数组转换为形状(M, N, 1),所有其他形状的数组将保持不变。- 返回类型:
示例
标量参数被转换为大小为 1 的 3D 数组
>>> x = jnp.float32(1.0) >>> jnp.atleast_3d(x) Array([[[1.]]], dtype=float32)
一维数组在开头和结尾添加了一个单位维度
>>> y = jnp.arange(4) >>> jnp.atleast_3d(y).shape (1, 4, 1)
二维数组在末尾添加了一个单位维度
>>> z = jnp.ones((2, 3)) >>> jnp.atleast_3d(z).shape (2, 3, 1)
可以一次向函数传递多个参数,在这种情况下,将返回结果列表
>>> x3, y3 = jnp.atleast_3d(x, y) >>> print(x3) [[[1.]]] >>> print(y3) [[[0] [1] [2] [3]]]