jax.export.export#

jax.export.export(fun_jit, *, platforms=None, disabled_checks=(), _override_lowering_rules=None)[源代码]#

导出 JAX 函数以进行持久序列化。

参数:
  • fun_jit (stages.Wrapped) – 要导出的函数。应为 jax.jit 的结果。

  • platforms (Sequence[str] | None | None) – 可选序列,包含 ‘tpu’、‘cpu’、‘cuda’、‘rocm’ 的子集。如果指定了多个平台,则导出的代码会接受一个参数来指定平台。如果为 None,则使用默认 JAX 后端。有关多个平台的调用约定,请参见 https://jax.net.cn/en/latest/export/export.html#module-calling-convention

  • _override_lowering_rules (Sequence[tuple[Any, Any]] | None | None) – 一组可选的自定义 lowering 规则,用于某些 JAX 原语。序列的每个元素都是一个 JAX 原语和一个 lowering 函数对。定义 lowering 规则是一项高级功能,使用 JAX 内部 API,这些 API 可能会发生更改。此外,通过这些自定义 lowering 规则发出的 MLIR 的稳定性责任由这些规则的用户承担。

  • disabled_checks (Sequence[DisabledSafetyCheck]) – 要禁用的安全检查。请参阅 jax.export.DisabledSafetyCheck 的文档。

返回:

一个函数,它接受 {class}`jax.ShapeDtypeStruct` 的 args 和 kwargs pytrees,或具有 .shape.dtype 属性的值,并返回一个 Exported

返回类型:

Callable[…, Exported]

用法

>>> from jax import export
>>> exported: export.Exported = export.export(jnp.sin)(
...     np.arange(4, dtype=np.float32))
>>>
>>> # You can inspect the Exported object
>>> exported.in_avals
(ShapedArray(float32[4]),)
>>> blob: bytearray = exported.serialize()
>>>
>>> # The serialized bytes are safe to use in a separate process
>>> rehydrated: export.Exported = export.deserialize(blob)
>>> rehydrated.fun_name
'sin'
>>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32))
Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32)