jax.numpy.shape#
- jax.numpy.shape(a)[源代码]#
返回数组的形状。
numpy.shape()
的 JAX 实现。与np.shape
不同,如果输入是列表或元组等集合,则此函数会引发TypeError
。示例
数组的形状
>>> x = jnp.arange(10) >>> jnp.shape(x) (10,) >>> y = jnp.ones((2, 3)) >>> jnp.shape(y) (2, 3)
这也适用于标量
>>> jnp.shape(3.14) ()
对于数组,也可以通过
jax.Array.shape
属性访问>>> x.shape (10,)