jax.numpy.count_nonzero#
- jax.numpy.count_nonzero(a, axis=None, keepdims=False)[源文件]#
返回沿给定轴的非零元素的数量。
JAX 对
numpy.count_nonzero()
的实现。- 参数:
a (ArrayLike) – 输入数组。
axis (轴) – 可选,整数或整数序列,默认值为 None。计算非零元素数量的轴。如果为 None,则计算扁平化数组中的非零元素数量。
keepdims (bool) – 布尔值,默认为 False。如果为 true,则在结果中保留大小为 1 的缩减轴。
- 返回:
一个数组,其中包含输入沿指定轴的非零元素数量。
- 返回类型:
示例
默认情况下,
jnp.count_nonzero
计算沿所有轴的非零值。>>> x = jnp.array([[1, 0, 0, 0], ... [0, 0, 1, 0], ... [1, 1, 1, 0]]) >>> jnp.count_nonzero(x) Array(5, dtype=int32)
如果
axis=1
,则沿轴 1 计算。>>> jnp.count_nonzero(x, axis=1) Array([1, 1, 3], dtype=int32)
要保留输入的维度,可以设置
keepdims=True
。>>> jnp.count_nonzero(x, axis=1, keepdims=True) Array([[1], [1], [3]], dtype=int32)