jax.numpy.shape#

jax.numpy.shape(a)[源代码]#

返回数组的形状。

numpy.shape()的 JAX 实现。与np.shape不同,如果输入是列表或元组等集合,则此函数会引发TypeError

参数:

a (类数组 | SupportsShape) – 类数组对象,或具有shape属性的任何对象。

返回:

一个表示 a 形状的整数元组。

返回类型:

tuple[int, …]

示例

数组的形状

>>> x = jnp.arange(10)
>>> jnp.shape(x)
(10,)
>>> y = jnp.ones((2, 3))
>>> jnp.shape(y)
(2, 3)

这也适用于标量

>>> jnp.shape(3.14)
()

对于数组,也可以通过jax.Array.shape属性访问

>>> x.shape
(10,)