jax.numpy.log1p#
- jax.numpy.log1p(x, /)[源]#
计算输入值加一后的逐元素对数,即
log(x+1)。JAX 对
numpy.log1p的实现。- 参数:
x (ArrayLike) – 输入数组或标量。
- 返回:
一个包含
x中每个元素加一后的对数的数组,会提升到非整数(inexact)数据类型。- 返回类型:
注意
对于
x的小值,jnp.log1p比使用朴素计算log(x+1)更准确。另请参阅
jax.numpy.expm1():计算输入的每个元素的 \(e^x-1\)。jax.numpy.log2():计算输入的每个元素的以2为底的对数。jax.numpy.log(): 计算输入的逐元素对数。
示例
>>> x = jnp.array([2, 5, 9, 4]) >>> jnp.allclose(jnp.log1p(x), jnp.log(x+1)) Array(True, dtype=bool)
对于非常接近 0 的值,
jnp.log1p(x)比jnp.log(x+1)更准确。>>> x1 = jnp.array([1e-4, 1e-6, 2e-10]) >>> jnp.expm1(jnp.log1p(x1)) Array([1.00000005e-04, 9.99999997e-07, 2.00000003e-10], dtype=float32) >>> jnp.expm1(jnp.log(x1+1)) Array([1.000166e-04, 9.536743e-07, 0.000000e+00], dtype=float32)