jax.numpy.size#

jax.numpy.size(a, axis=None)[source]#

返回给定轴上的元素数量。

JAX 实现的 numpy.size()。与 np.size 不同,如果输入是列表或元组之类的集合,此函数会引发 TypeError

参数:
  • a (ArrayLike | SupportsSize | SupportsShape) – 类数组对象,或在未指定 axis 时具有 size 属性的任何对象,或在指定 axis 时具有 shape 属性的对象。

  • axis (int | Sequence[int] | None) – 可选的整数或整数序列,指示沿哪个或哪些轴计算元素。 None(默认)返回元素总数。

返回:

一个整数,指定 a 中的元素数量。

返回类型:

int

示例

数组的大小

>>> x = jnp.arange(10)
>>> jnp.size(x)
10
>>> y = jnp.ones((2, 3))
>>> jnp.size(y)
6
>>> jnp.size(y, axis=1)
3
>>> jnp.size(y, axis=(1,))
3
>>> jnp.size(y, axis=(0, 1))
6

这也适用于标量

>>> jnp.size(3.14)
1

对于数组,也可以通过 jax.Array.size 属性访问

>>> y.size
6