jax.numpy.isscalar#
- jax.numpy.isscalar(element)[源]#
如果输入是标量,则返回 True。
JAX 对
numpy.isscalar()
的实现。JAX 的实现与 NumPy 的不同之处在于,它将零维数组视为标量;请参阅下面的注意以了解更多详情。- 参数:
element (任意类型) – 要检查的输入对象;任何类型都是有效输入。
- 返回:
如果
element
是标量值或零维类数组对象,则返回 True,否则返回 False。- 返回类型:
注意
JAX 和 NumPy 在标量值的表示上有所不同。NumPy 具有特殊的标量对象(例如
np.int32(0)
),这些对象与零维数组(例如np.array(0)
)是不同的,并且numpy.isscalar()
对于前者返回True
,对于后者返回False
。JAX 不定义特殊的标量对象,而是将标量表示为零维数组。因此,
jax.numpy.isscalar()
对于标量对象(例如0.0
或np.float32(0.0)
)和零维类数组对象(例如jnp.array(0.0)
,np.array(0.0)
)都返回True
。中不同约定的一个原因是维持 JIT 不变性:即函数在 JIT 编译后其结果不应改变的属性。由于标量输入在 JIT 边界处被转换为零维 JAX 数组,因此
numpy.isscalar()
的语义导致结果在 JIT 下发生变化。>>> np.isscalar(1.0) True >>> jax.jit(np.isscalar)(1.0) Array(False, dtype=bool)
通过将零维数组视为标量,
jax.numpy.isscalar()
避免了这个问题。>>> jnp.isscalar(1.0) True >>> jax.jit(jnp.isscalar)(1.0) Array(True, dtype=bool)
示例
在 JAX 中,标量和零维类数组对象都被视为标量。
>>> jnp.isscalar(1.0) True >>> jnp.isscalar(1 + 1j) True >>> jnp.isscalar(jnp.array(1)) # zero-dimensional JAX array True >>> jnp.isscalar(jnp.int32(1)) # JAX scalar constructor True >>> jnp.isscalar(np.array(1.0)) # zero-dimensional NumPy array True >>> jnp.isscalar(np.int32(1)) # NumPy scalar type True
一维或多维数组不被视为标量。
>>> jnp.isscalar(jnp.array([1])) False >>> jnp.isscalar(np.array([1])) False
将其与
numpy.isscalar()
进行比较,后者对于标量类型对象返回True
,对于所有数组(即使是零维数组)返回False
。>>> np.isscalar(np.int32(1)) # scalar object True >>> np.isscalar(np.array(1)) # zero-dimensional array False
在 JAX 中,与 NumPy 中一样,非类数组对象不被视为标量。
>>> jnp.isscalar(None) False >>> jnp.isscalar([1]) False >>> jnp.isscalar(tuple()) False >>> jnp.isscalar(slice(10)) False