jax.numpy.isscalar#

jax.numpy.isscalar(element)[source]#

如果输入是标量,则返回 True。

JAX 实现的 numpy.isscalar()。JAX 的实现与 NumPy 的不同之处在于,它认为零维数组是标量;详见下面的注意

参数:

element (Any) – 要检查的输入对象;任何类型都是有效的输入。

返回:

如果 element 是标量值或零维类数组对象,则返回 True,否则返回 False。

返回类型:

bool

注意

JAX 和 NumPy 在标量值的表示上有所不同。NumPy 有特殊的标量对象(例如 np.int32(0)),这些对象与零维数组(例如 np.array(0))不同,并且 numpy.isscalar() 对于前者返回 True,对于后者返回 False

JAX 没有定义特殊的标量对象,而是将标量表示为零维数组。因此,jax.numpy.isscalar() 对于标量对象(例如 0.0np.float32(0.0))和零维类数组对象(例如 jnp.array(0.0), np.array(0.0)) 都返回 True

isscalar 中采用不同约定的一个原因是保持 JIT 不变性:即,当函数经过 JIT 编译时,其结果不应发生变化。由于标量输入在 JIT 边界处被转换为零维 JAX 数组,因此 numpy.isscalar() 的语义使得结果在 JIT 下会发生变化

>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)

通过将零维数组视为标量,jax.numpy.isscalar() 避免了这个问题

>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)

示例

在 JAX 中,标量和零维类数组对象都被认为是标量

>>> jnp.isscalar(1.0)
True
>>> jnp.isscalar(1 + 1j)
True
>>> jnp.isscalar(jnp.array(1))  # zero-dimensional JAX array
True
>>> jnp.isscalar(jnp.int32(1))  # JAX scalar constructor
True
>>> jnp.isscalar(np.array(1.0))  # zero-dimensional NumPy array
True
>>> jnp.isscalar(np.int32(1))  # NumPy scalar type
True

具有一个或多个维度的数组不被认为是标量

>>> jnp.isscalar(jnp.array([1]))
False
>>> jnp.isscalar(np.array([1]))
False

numpy.isscalar() 比较,后者对于标量类型对象返回 True,对于所有数组(即使是零维数组)返回 False

>>> np.isscalar(np.int32(1))  # scalar object
True
>>> np.isscalar(np.array(1))  # zero-dimensional array
False

在 JAX 中,与 NumPy 中一样,非类数组对象不被认为是标量

>>> jnp.isscalar(None)
False
>>> jnp.isscalar([1])
False
>>> jnp.isscalar(tuple())
False
>>> jnp.isscalar(slice(10))
False