jax.typing 模块#

JAX 类型模块是 JAX 特有的静态类型注解的存放位置。此子模块正在开发中;要了解此处导出的类型背后的提案,请参阅 https://jax.net.cn/en/latest/jep/12049-type-annotations.html

目前可用的类型包括:

  • jax.Array: 用于任何 JAX 数组或追踪器(即 JAX 转换中的数组表示)的注解。

  • jax.typing.ArrayLike: 用于任何可以安全地隐式转换为 JAX 数组的值的注解;这包括 jax.Arraynumpy.ndarray,以及 Python 内置数值类型(例如 intfloat 等)和 NumPy 标量值(例如 numpy.int32numpy.float64 等)。

  • jax.typing.DTypeLike: 用于任何可以转换为 JAX 兼容 dtype 的值的注解;这包括字符串(例如 ‘float32’‘int32’)、标量类型(例如 floatnp.float32)、dtypes(例如 np.dtype(‘float32’))或具有 dtype 属性的对象(例如 jnp.float32jnp.int32)。

我们可能会在未来的版本中在此处添加其他类型。

JAX 类型最佳实践#

在公共 API 函数中注解 JAX 数组时,我们建议对数组输入使用 ArrayLike,对数组输出使用 Array

例如,您的函数可能如下所示:

import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def my_function(x: ArrayLike) -> Array:
  # Runtime type validation, Python 3.10 or newer:
  if not isinstance(x, ArrayLike):
    raise TypeError(f"Expected arraylike input; got {x}")
  # Runtime type validation, any Python version:
  if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
    raise TypeError(f"Expected arraylike input; got {x}")

  # Convert input to jax.Array:
  x_arr = jnp.asarray(x)

  # ... do some computation; JAX functions will return Array types:
  result = x_arr.sum(0) / x_arr.shape[0]

  # return an Array
  return result

大多数 JAX 的公共 API 都遵循此模式。请特别注意,我们建议 JAX 函数不要接受诸如 listtuple 之类的序列来代替数组,因为这可能会在诸如 jit() 之类的 JAX 转换中导致额外的开销,并且在诸如 vmap()jax.pmap() 之类的批处理转换中可能以意想不到的方式运行。有关此方面的更多信息,请参阅 非数组输入:NumPy 与 JAX

成员列表#

ArrayLike

JAX 数组类对象的类型注解。

DTypeLike

别名,类型为 str | type[Any] | dtype | SupportsDType