jax.numpy.size#
- jax.numpy.size(a, axis=None)[source]#
返回给定轴上的元素数量。
JAX 实现的
numpy.size()
。与np.size
不同,如果输入是列表或元组之类的集合,此函数会引发TypeError
。- 参数:
- 返回:
一个整数,指定
a
中的元素数量。- 返回类型:
示例
数组的大小
>>> x = jnp.arange(10) >>> jnp.size(x) 10 >>> y = jnp.ones((2, 3)) >>> jnp.size(y) 6 >>> jnp.size(y, axis=1) 3 >>> jnp.size(y, axis=(1,)) 3 >>> jnp.size(y, axis=(0, 1)) 6
这也适用于标量
>>> jnp.size(3.14) 1
对于数组,也可以通过
jax.Array.size
属性访问>>> y.size 6