jax.numpy.array_equal#
- jax.numpy.array_equal(a1, a2, equal_nan=False)[来源]#
检查两个数组是否逐元素相等。
JAX 对
numpy.array_equal()
的实现。- 参数:
a1 (ArrayLike) – 第一个要比较的输入数组。
a2 (ArrayLike) – 第二个要比较的输入数组。
equal_nan (bool) – 布尔值。如果
True
,a1 中的 NaN 值将被视为与 a2 中的 NaN 值相等。默认为False
。
- 返回:
布尔标量数组,指示输入数组是否逐元素相等。
- 返回类型:
示例
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3])) Array(True, dtype=bool) >>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2])) Array(False, dtype=bool) >>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4])) Array(False, dtype=bool) >>> jnp.array_equal(jnp.array([1, 2, float('nan')]), ... jnp.array([1, 2, float('nan')])) Array(False, dtype=bool) >>> jnp.array_equal(jnp.array([1, 2, float('nan')]), ... jnp.array([1, 2, float('nan')]), equal_nan=True) Array(True, dtype=bool)