jax.numpy.isscalar#

jax.numpy.isscalar(element)[源代码]#

如果输入是标量,则返回 True。

JAX 对 numpy.isscalar() 的实现。JAX 的实现与 NumPy 的不同之处在于,它将零维数组视为标量;有关更多详细信息,请参阅下方的注意

参数:

element (Any) – 要检查的输入对象;任何类型都可作为有效输入。

返回:

如果 element 是标量值或具有零维的类数组对象,则返回 True,否则返回 False。

返回类型:

bool

注意

JAX 和 NumPy 在标量值的表示上有所不同。NumPy 有特殊的标量对象(例如 np.int32(0)),它们与零维数组(例如 np.array(0))不同,而 numpy.isscalar() 对于前者返回 True,对于后者返回 False

JAX 不定义特殊的标量对象,而是将标量表示为零维数组。因此,jax.numpy.isscalar() 对于标量对象(例如 0.0np.float32(0.0))和具有零维的类数组对象(例如 jnp.array(0.0)np.array(0.0))都返回 True

isscalar 中采用不同约定的原因之一是为了保持 JIT 不变性:即函数的结果在 JIT 编译时不会改变。由于标量输入在 JIT 边界处会被转换为零维 JAX 数组,因此 numpy.isscalar() 的语义使得结果在 JIT 下会发生变化。

>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)

通过将零维数组视为标量,jax.numpy.isscalar() 避免了这个问题。

>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)

示例

在 JAX 中,标量和零维类数组对象都被视为标量。

>>> jnp.isscalar(1.0)
True
>>> jnp.isscalar(1 + 1j)
True
>>> jnp.isscalar(jnp.array(1))  # zero-dimensional JAX array
True
>>> jnp.isscalar(jnp.int32(1))  # JAX scalar constructor
True
>>> jnp.isscalar(np.array(1.0))  # zero-dimensional NumPy array
True
>>> jnp.isscalar(np.int32(1))  # NumPy scalar type
True

维度为一或多维的数组不被视为标量。

>>> jnp.isscalar(jnp.array([1]))
False
>>> jnp.isscalar(np.array([1]))
False

numpy.isscalar() 相比,后者对于标量类型对象返回 True,而对于所有数组,即使是零维数组,都返回 False

>>> np.isscalar(np.int32(1))  # scalar object
True
>>> np.isscalar(np.array(1))  # zero-dimensional array
False

在 JAX 中,与 NumPy 一样,非类数组对象不被视为标量。

>>> jnp.isscalar(None)
False
>>> jnp.isscalar([1])
False
>>> jnp.isscalar(())
False
>>> jnp.isscalar(slice(10))
False