jax.export.export#
- jax.export.export(fun_jit, *, platforms=None, disabled_checks=(), _override_lowering_rules=None)[源代码]#
导出 JAX 函数以进行持久化序列化。
- 参数:
fun_jit (stages.Wrapped) – 要导出的函数。应该是
jax.jit()的结果。platforms (Sequence[str] | None) – 可选序列,包含 'tpu'、'cpu'、'cuda'、'rocm' 的子集。如果指定了多个平台,则导出的代码将接受一个指定平台的参数。如果为 None,则使用默认的 JAX 后端。多平台调用约定在 https://jax.net.cn/en/latest/export/export.html#module-calling-convention 中进行了解释。
_override_lowering_rules (Sequence[tuple[Any, Any]] | None) – 用于某些 JAX 原语的可选自定义降低规则序列。序列的每个元素都是一个 JAX 原语和一个降低函数的对。定义降低规则是一项高级功能,使用了 JAX 的内部 API,这些 API 可能会发生变化。此外,通过这些自定义降低规则发出的 MLIR 的稳定性责任由这些规则的用户承担。
disabled_checks (Sequence[DisabledSafetyCheck]) – 要禁用的安全检查。请参阅
jax.export.DisabledSafetyCheck的文档。
- 返回:
一个函数,该函数接受
jax.ShapeDtypeStruct的 args 和 kwargs Pytree,或具有.shape和.dtype属性的值,并返回一个Exported。- 返回类型:
Callable[…, Exported]
用法
>>> from jax import export >>> exported: export.Exported = export.export(jnp.sin)( ... np.arange(4, dtype=np.float32)) >>> >>> # You can inspect the Exported object >>> exported.in_avals (ShapedArray(float32[4]),) >>> blob: bytearray = exported.serialize() >>> >>> # The serialized bytes are safe to use in a separate process >>> rehydrated: export.Exported = export.deserialize(blob) >>> rehydrated.fun_name 'sin' >>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32)) Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32)