jax.numpy.size#

jax.numpy.size(a, axis=None)[源代码]#

返回沿给定轴的元素数量。

JAX 对 numpy.size() 的实现。与 np.size 不同,如果输入是集合(例如列表或元组),此函数会引发 TypeError

参数:
  • a (ArrayLike | SupportsSize | SupportsShape) – 类似数组的对象,或当未指定 axis 时具有 size 属性的任何对象,或当指定 axis 时具有 shape 属性的对象。

  • axis (int | None | None) – 可选整数,沿该整数计算元素。默认情况下,返回元素总数。

返回:

一个整数,指定 a 中的元素数量。

返回类型:

int

示例

数组的大小

>>> x = jnp.arange(10)
>>> jnp.size(x)
10
>>> y = jnp.ones((2, 3))
>>> jnp.size(y)
6
>>> jnp.size(y, axis=1)
3

这也适用于标量

>>> jnp.size(3.14)
1

对于数组,也可以通过 jax.Array.size 属性访问

>>> y.size
6