jax.numpy.size#
- jax.numpy.size(a, axis=None)[源代码]#
返回沿给定轴的元素数量。
JAX 对
numpy.size()
的实现。与np.size
不同,如果输入是集合(例如列表或元组),此函数会引发TypeError
。- 参数:
a (ArrayLike | SupportsSize | SupportsShape) – 类似数组的对象,或当未指定
axis
时具有size
属性的任何对象,或当指定axis
时具有shape
属性的对象。axis (int | None | None) – 可选整数,沿该整数计算元素。默认情况下,返回元素总数。
- 返回:
一个整数,指定
a
中的元素数量。- 返回类型:
示例
数组的大小
>>> x = jnp.arange(10) >>> jnp.size(x) 10 >>> y = jnp.ones((2, 3)) >>> jnp.size(y) 6 >>> jnp.size(y, axis=1) 3
这也适用于标量
>>> jnp.size(3.14) 1
对于数组,也可以通过
jax.Array.size
属性访问>>> y.size 6