jax.numpy.ndim#
- jax.numpy.ndim(a)[源代码]#
返回数组的维度数。
JAX 实现的
numpy.ndim()
。与np.ndim
不同,如果输入是列表或元组等集合,此函数会引发TypeError
。- 参数:
a (ArrayLike | SupportsNdim) – 类数组对象,或任何具有
ndim
属性的对象。- 返回值:
一个整数,指定
a
的维度数。- 返回类型:
示例
数组的维度数
>>> x = jnp.arange(10) >>> jnp.ndim(x) 1 >>> y = jnp.ones((2, 3)) >>> jnp.ndim(y) 2
这也适用于标量
>>> jnp.ndim(3.14) 0
对于数组,也可以通过
jax.Array.ndim
属性访问>>> x.ndim 1