jax.numpy.isscalar#

jax.numpy.isscalar(element)[源]#

如果输入是标量,则返回 True。

JAX 对 numpy.isscalar() 的实现。JAX 的实现与 NumPy 的不同之处在于,它将零维数组视为标量;请参阅下面的注意以了解更多详情。

参数:

element (任意类型) – 要检查的输入对象;任何类型都是有效输入。

返回:

如果 element 是标量值或零维类数组对象,则返回 True,否则返回 False。

返回类型:

bool

注意

JAX 和 NumPy 在标量值的表示上有所不同。NumPy 具有特殊的标量对象(例如 np.int32(0)),这些对象与零维数组(例如 np.array(0))是不同的,并且 numpy.isscalar() 对于前者返回 True,对于后者返回 False

JAX 不定义特殊的标量对象,而是将标量表示为零维数组。因此,jax.numpy.isscalar() 对于标量对象(例如 0.0np.float32(0.0))和零维类数组对象(例如 jnp.array(0.0), np.array(0.0))都返回 True

中不同约定的一个原因是维持 JIT 不变性:即函数在 JIT 编译后其结果不应改变的属性。由于标量输入在 JIT 边界处被转换为零维 JAX 数组,因此 numpy.isscalar() 的语义导致结果在 JIT 下发生变化。

>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)

通过将零维数组视为标量,jax.numpy.isscalar() 避免了这个问题。

>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)

示例

在 JAX 中,标量和零维类数组对象都被视为标量。

>>> jnp.isscalar(1.0)
True
>>> jnp.isscalar(1 + 1j)
True
>>> jnp.isscalar(jnp.array(1))  # zero-dimensional JAX array
True
>>> jnp.isscalar(jnp.int32(1))  # JAX scalar constructor
True
>>> jnp.isscalar(np.array(1.0))  # zero-dimensional NumPy array
True
>>> jnp.isscalar(np.int32(1))  # NumPy scalar type
True

一维或多维数组不被视为标量。

>>> jnp.isscalar(jnp.array([1]))
False
>>> jnp.isscalar(np.array([1]))
False

将其与 numpy.isscalar() 进行比较,后者对于标量类型对象返回 True,对于所有数组(即使是零维数组)返回 False

>>> np.isscalar(np.int32(1))  # scalar object
True
>>> np.isscalar(np.array(1))  # zero-dimensional array
False

在 JAX 中,与 NumPy 中一样,非类数组对象不被视为标量。

>>> jnp.isscalar(None)
False
>>> jnp.isscalar([1])
False
>>> jnp.isscalar(tuple())
False
>>> jnp.isscalar(slice(10))
False