jax.numpy.shape#
- jax.numpy.shape(a)[源代码]#
返回数组的形状。
JAX 实现的
numpy.shape()
。与np.shape
不同,如果输入是列表或元组等集合,此函数会引发TypeError
。- 参数:
a (ArrayLike | SupportsShape) – 类似数组的对象,或任何具有
shape
属性的对象。- 返回:
表示
a
形状的整数元组。- 返回类型:
示例
数组的形状
>>> x = jnp.arange(10) >>> jnp.shape(x) (10,) >>> y = jnp.ones((2, 3)) >>> jnp.shape(y) (2, 3)
这也适用于标量
>>> jnp.shape(3.14) ()
对于数组,也可以通过
jax.Array.shape
属性访问>>> x.shape (10,)