jax.numpy.issubdtype#
- jax.numpy.issubdtype(arg1, arg2)[源代码]#
如果 arg1 在类型层次结构中等于或低于 arg2,则返回 True。
JAX 对
numpy.issubdtype()的实现。JAX 实现的主要区别在于它能正确处理 dtype 扩展,例如
bfloat16。- 参数:
arg1 (DTypeLike) – 类 dtype 对象。典型用法中,这将是一个 dtype 说明符,例如
"float32"(即字符串)、np.dtype('int32')(即numpy.dtype的实例)、jnp.complex64(即 JAX 标量构造函数)或np.uint8(即 NumPy 标量类型)。arg2 (DTypeLike) – 类 dtype 对象。典型用法中,这将是一个通用标量类型,例如
jnp.integer、jnp.floating或jnp.complexfloating。
- 返回:
如果 arg1 表示的 dtype 在类型层次结构中等于或低于 arg2,则为 True。
- 返回类型:
另请参阅
jax.numpy.isdtype():一个类似的函数,符合 Array API 标准。
示例
>>> jnp.issubdtype('uint32', jnp.unsignedinteger) True >>> jnp.issubdtype(np.int32, jnp.integer) True >>> jnp.issubdtype(jnp.bfloat16, jnp.floating) True >>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating) True >>> jnp.issubdtype('complex64', jnp.integer) False
请注意,虽然这与
numpy.issubdtype()非常相似,但在 JAX 的自定义浮点类型情况下,这些函数的结果会有所不同。>>> np.issubdtype('bfloat16', np.floating) False >>> jnp.issubdtype('bfloat16', jnp.floating) True