jax.numpy.count_nonzero#
- jax.numpy.count_nonzero(a, axis=None, keepdims=False)[源代码]#
计算给定轴上的非零元素的数量。
JAX 对
numpy.count_nonzero()的实现。- 参数:
a (ArrayLike) – 输入数组。
axis (轴) – 可选,int 或 int 序列,默认为 None。计算非零个数的轴。如果为 None,则在扁平化数组中进行计数。
keepdims (bool) – bool,默认为 False。如果为 True,则保留大小为 1 的约简轴。
- 返回:
包含输入数组指定轴上非零元素数量的数组。
- 返回类型:
示例
默认情况下,
jnp.count_nonzero会沿所有轴计算非零值。>>> x = jnp.array([[1, 0, 0, 0], ... [0, 0, 1, 0], ... [1, 1, 1, 0]]) >>> jnp.count_nonzero(x) Array(5, dtype=int32)
如果
axis=1,则沿轴 1 计算。>>> jnp.count_nonzero(x, axis=1) Array([1, 1, 3], dtype=int32)
要保留输入的维度,可以设置
keepdims=True。>>> jnp.count_nonzero(x, axis=1, keepdims=True) Array([[1], [1], [3]], dtype=int32)