jax.numpy.heaviside#
- jax.numpy.heaviside(x1, x2, /)[源代码]#
计算 heaviside 阶跃函数。
numpy.heaviside
的 JAX 实现。heaviside 阶跃函数定义为
\[\begin{split}\mathrm{heaviside}(x1, x2) = \begin{cases} 0, & x1 < 0\\ x2, & x1 = 0\\ 1, & x1 > 0. \end{cases}\end{split}\]- 参数:
x1 (ArrayLike) – 输入数组或标量。不支持
complex
dtype。x2 (ArrayLike) – 标量或数组。指定当
x1
为0
时的返回值。不支持complex
dtype。x1
和x2
必须具有相同的形状或广播兼容。
- 返回:
一个包含
x1
的 heaviside 阶跃函数的数组,提升为非精确 dtype。- 返回类型:
示例
>>> x1 = jnp.array([[-2, 0, 3], ... [5, -1, 0], ... [0, 7, -3]]) >>> x2 = jnp.array([2, 0.5, 1]) >>> jnp.heaviside(x1, x2) Array([[0. , 0.5, 1. ], [1. , 0. , 1. ], [2. , 1. , 0. ]], dtype=float32) >>> jnp.heaviside(x1, 0.5) Array([[0. , 0.5, 1. ], [1. , 0. , 0.5], [0.5, 1. , 0. ]], dtype=float32) >>> jnp.heaviside(-3, x2) Array([0., 0., 0.], dtype=float32)