jax.numpy.isscalar#
- jax.numpy.isscalar(element)[源代码]#
如果输入是标量,则返回 True。
JAX 对
numpy.isscalar()的实现。JAX 的实现与 NumPy 的不同之处在于,它将零维数组视为标量;有关更多详细信息,请参阅下方的注意。- 参数:
element (Any) – 要检查的输入对象;任何类型都可作为有效输入。
- 返回:
如果
element是标量值或具有零维的类数组对象,则返回 True,否则返回 False。- 返回类型:
注意
JAX 和 NumPy 在标量值的表示上有所不同。NumPy 有特殊的标量对象(例如
np.int32(0)),它们与零维数组(例如np.array(0))不同,而numpy.isscalar()对于前者返回True,对于后者返回False。JAX 不定义特殊的标量对象,而是将标量表示为零维数组。因此,
jax.numpy.isscalar()对于标量对象(例如0.0或np.float32(0.0))和具有零维的类数组对象(例如jnp.array(0.0)、np.array(0.0))都返回True。在
isscalar中采用不同约定的原因之一是为了保持 JIT 不变性:即函数的结果在 JIT 编译时不会改变。由于标量输入在 JIT 边界处会被转换为零维 JAX 数组,因此numpy.isscalar()的语义使得结果在 JIT 下会发生变化。>>> np.isscalar(1.0) True >>> jax.jit(np.isscalar)(1.0) Array(False, dtype=bool)
通过将零维数组视为标量,
jax.numpy.isscalar()避免了这个问题。>>> jnp.isscalar(1.0) True >>> jax.jit(jnp.isscalar)(1.0) Array(True, dtype=bool)
示例
在 JAX 中,标量和零维类数组对象都被视为标量。
>>> jnp.isscalar(1.0) True >>> jnp.isscalar(1 + 1j) True >>> jnp.isscalar(jnp.array(1)) # zero-dimensional JAX array True >>> jnp.isscalar(jnp.int32(1)) # JAX scalar constructor True >>> jnp.isscalar(np.array(1.0)) # zero-dimensional NumPy array True >>> jnp.isscalar(np.int32(1)) # NumPy scalar type True
维度为一或多维的数组不被视为标量。
>>> jnp.isscalar(jnp.array([1])) False >>> jnp.isscalar(np.array([1])) False
与
numpy.isscalar()相比,后者对于标量类型对象返回True,而对于所有数组,即使是零维数组,都返回False。>>> np.isscalar(np.int32(1)) # scalar object True >>> np.isscalar(np.array(1)) # zero-dimensional array False
在 JAX 中,与 NumPy 一样,非类数组对象不被视为标量。
>>> jnp.isscalar(None) False >>> jnp.isscalar([1]) False >>> jnp.isscalar(()) False >>> jnp.isscalar(slice(10)) False