jax.numpy.shape#

jax.numpy.shape(a)[源代码]#

返回数组的形状。

JAX 实现的 numpy.shape()。与 np.shape 不同,如果输入是列表或元组等集合,此函数会引发 TypeError

参数:

a (ArrayLike | SupportsShape) – 类似数组的对象,或任何具有 shape 属性的对象。

返回:

表示 a 形状的整数元组。

返回类型:

tuple[int, …]

示例

数组的形状

>>> x = jnp.arange(10)
>>> jnp.shape(x)
(10,)
>>> y = jnp.ones((2, 3))
>>> jnp.shape(y)
(2, 3)

这也适用于标量

>>> jnp.shape(3.14)
()

对于数组,也可以通过 jax.Array.shape 属性访问

>>> x.shape
(10,)