jax.numpy.ndim#

jax.numpy.ndim(a)[源代码]#

返回数组的维度数。

JAX 实现的 numpy.ndim()。与 np.ndim 不同,如果输入是列表或元组等集合,此函数会引发 TypeError

参数:

a (ArrayLike | SupportsNdim) – 类数组对象,或任何具有 ndim 属性的对象。

返回值:

一个整数,指定 a 的维度数。

返回类型:

int

示例

数组的维度数

>>> x = jnp.arange(10)
>>> jnp.ndim(x)
1
>>> y = jnp.ones((2, 3))
>>> jnp.ndim(y)
2

这也适用于标量

>>> jnp.ndim(3.14)
0

对于数组,也可以通过 jax.Array.ndim 属性访问

>>> x.ndim
1