jax.numpy.nan_to_num#
- jax.numpy.nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)[源]#
替换数组中的 NaN 和无穷大值。
JAX 对
numpy.nan_to_num()
的实现。- 参数:
x (ArrayLike) – 要替换的值的数组。如果它没有不精确的 dtype,则将原样返回。
copy (bool) – JAX 未使用
nan (ArrayLike) – 用于替代 NaN 条目的值。默认为 0.0。
posinf (ArrayLike | None) – 用于替代正无穷大条目的值。默认为最大可表示值。
neginf (ArrayLike | None) – 用于替代正无穷大条目的值。默认为最小可表示值。
- 返回:
替换了请求值的
x
的副本。- 返回类型:
另请参阅
jax.numpy.isnan()
: 当数组包含 NaN 时返回 Truejax.numpy.isposinf()
: 当数组包含 +inf 时返回 Truejax.numpy.isneginf()
: 当数组包含 -inf 时返回 True
示例
>>> x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])
默认替换值
>>> jnp.nan_to_num(x) Array([ 0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 3.4028235e+38, 2.0000000e+00, -3.4028235e+38], dtype=float32)
覆盖 -inf 和 +inf 的替换值
>>> jnp.nan_to_num(x, posinf=999, neginf=-999) Array([ 0., 0., 1., 999., 2., -999.], dtype=float32)
如果您只想替换 NaN 值而保留
inf
值不变,那么使用where()
和jax.numpy.isnan()
会是更好的选择>>> jnp.where(jnp.isnan(x), 0, x) Array([ 0., 0., 1., inf, 2., -inf], dtype=float32)