jax.numpy.shape#
- jax.numpy.shape(a)[源代码]#
返回数组的形状。
JAX 对
numpy.shape()
的实现。与np.shape
不同,如果输入是列表或元组等集合,此函数将引发TypeError
。示例
数组的形状
>>> x = jnp.arange(10) >>> jnp.shape(x) (10,) >>> y = jnp.ones((2, 3)) >>> jnp.shape(y) (2, 3)
这也适用于标量
>>> jnp.shape(3.14) ()
对于数组,也可以通过
jax.Array.shape
属性进行访问。>>> x.shape (10,)