jax.typing
模块#
JAX 类型注解模块是 JAX 特定的静态类型注解所在的地方。这个子模块仍在开发中;要查看此处导出的类型背后的提议,请参阅 https://jax.net.cn/en/latest/jep/12049-type-annotations.html。
目前可用的类型有
jax.Array
:用于任何 JAX 数组或追踪器(即 JAX 转换中数组的表示)的注解。jax.typing.ArrayLike
:用于任何可以安全地隐式转换为 JAX 数组的值的注解;这包括jax.Array
、numpy.ndarray
,以及 Python 内置的数值(例如int
、float
等)和 NumPy 标量值(例如numpy.int32
、numpy.float64
等)。jax.typing.DTypeLike
:用于任何可以转换为 JAX 兼容 dtype 的值的注解;这包括字符串(例如 ‘float32’、‘int32’)、标量类型(例如 float、np.float32)、dtypes(例如 np.dtype(‘float32’)),或具有 dtype 属性的对象(例如 jnp.float32、jnp.int32)。
我们可能会在未来的版本中添加更多的类型。
JAX 类型注解最佳实践#
在公共 API 函数中注解 JAX 数组时,我们建议对数组输入使用 ArrayLike
,对数组输出使用 Array
。
例如,你的函数可能看起来像这样
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike
def my_function(x: ArrayLike) -> Array:
# Runtime type validation, Python 3.10 or newer:
if not isinstance(x, ArrayLike):
raise TypeError(f"Expected arraylike input; got {x}")
# Runtime type validation, any Python version:
if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
raise TypeError(f"Expected arraylike input; got {x}")
# Convert input to jax.Array:
x_arr = jnp.asarray(x)
# ... do some computation; JAX functions will return Array types:
result = x_arr.sum(0) / x_arr.shape[0]
# return an Array
return result
大多数 JAX 的公共 API 都遵循这种模式。特别要注意的是,我们建议 JAX 函数不接受序列(例如 list
或 tuple
)来代替数组,因为这会在 jit()
等 JAX 转换中导致额外的开销,并且在 vmap()
或 jax.pmap()
等批量转换中可能表现出意想不到的行为。有关更多信息,请参阅 非数组输入 NumPy 与 JAX 的区别
成员列表#
JAX 类数组对象的类型注解。 |
|