jax.jit

目录

jax.jit#

jax.jit(fun: Callable, /, *, in_shardings: Any = UnspecifiedValue, out_shardings: Any = UnspecifiedValue, static_argnums: int | Sequence[int] | None = None, static_argnames: str | Iterable[str] | None = None, donate_argnums: int | Sequence[int] | None = None, donate_argnames: str | Iterable[str] | None = None, keep_unused: bool = False, device: xc.Device | None = None, backend: str | None = None, inline: bool = False, compiler_options: dict[str, Any] | None = None) pjit.JitWrapped[source]#
jax.jit(*, in_shardings: Any = UnspecifiedValue, out_shardings: Any = UnspecifiedValue, static_argnums: int | Sequence[int] | None = None, static_argnames: str | Iterable[str] | None = None, donate_argnums: int | Sequence[int] | None = None, donate_argnames: str | Iterable[str] | None = None, keep_unused: bool = False, device: xc.Device | None = None, backend: str | None = None, inline: bool = False, compiler_options: dict[str, Any] | None = None) Callable[[Callable], pjit.JitWrapped]

fun 设置 JIT(即时)编译,以适配 XLA。

参数:
  • fun – 待 JIT 编译的函数。fun 应当是一个纯函数。其参数和返回值应当为数组、标量或此类对象的(嵌套)标准 Python 容器(元组/列表/字典)。由 static_argnums 指定的位置参数可以是任何可哈希类型。静态参数将包含在编译缓存键中,因此必须定义哈希和相等运算符。JAX 会保留指向 fun 的弱引用,将其用作编译缓存键,因此对象 fun 必须是弱引用的。从 JAX v0.8.1 开始,如果省略 fun,返回值将为一个部分求值的函数,以支持装饰器工厂模式(请参阅下方的示例)。

  • in_shardings – 可选参数,一个 Sharding 或包含 Sharding 叶节点的 pytree,其结构是传递给 fun 的位置参数元组的树前缀。如果提供,则传递给 fun 的位置参数必须具有与 in_shardings 兼容的分片(Sharding),否则会报错;且编译后的计算将具有与 in_shardings 相对应的输入分片。若未提供,编译后计算的输入分片将从参数分片中推断得出。

  • out_shardings – 可选参数,一个 Sharding 或包含 Sharding 叶节点的 pytree,其结构是 fun 输出的树前缀。如果提供,其效果等同于将 jax.lax.with_sharding_constraint() 应用于 fun 的输出。

  • static_argnums

    可选参数,一个整数或整数集合,用于指定哪些位置参数应被视为静态(跟踪时和编译时的常量)。

    静态参数必须是可哈希的(即实现了 __hash____eq__)且不可变。除此之外,它们可以是任意 Python 对象。使用这些常量的不同值调用 JIT 编译后的函数将触发重新编译。非数组类(或非此类容器)的参数必须被标记为静态。

    如果既未提供 static_argnums 也未提供 static_argnames,则没有参数会被视为静态。如果仅提供了其中之一,JAX 会使用 inspect.signature(fun) 来查找与 static_argnames 对应的位置参数(反之亦然)。如果同时提供了 static_argnumsstatic_argnames,则不使用 inspect.signature,只有列在两者中的实际参数才会被视为静态。

  • static_argnames – 可选参数,一个字符串或字符串集合,指定哪些命名参数应被视为静态(编译时常量)。详情请参阅 static_argnums 的说明。如果未提供该参数但设置了 static_argnums,则默认行为基于调用 inspect.signature(fun) 来查找对应的命名参数。

  • donate_argnums

    可选参数,一个整数集合,用于指定哪些位置参数缓冲区可以被计算覆盖,并在调用方中标记为已删除。如果您在计算开始后不再需要参数缓冲区,那么捐赠它们是安全的。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如重用输入缓冲区来存储结果。您不应重复使用捐赠给计算的缓冲区;如果您尝试这样做,JAX 将报错。默认情况下,不捐赠任何参数缓冲区。

    如果既未提供 donate_argnums 也未提供 donate_argnames,则不捐赠任何参数。如果仅提供了其中之一,JAX 会使用 inspect.signature(fun) 来查找对应的参数。如果同时提供了 donate_argnumsdonate_argnames,则不使用 inspect.signature,只有列在两者中的实际参数才会被捐赠。

    有关缓冲区捐赠的更多详细信息,请参阅 FAQ

  • donate_argnames – 可选参数,一个字符串或字符串集合,指定哪些命名参数将被捐赠给计算过程。详情请参阅 donate_argnums 的说明。如果未提供该参数但设置了 donate_argnums,则默认行为基于调用 inspect.signature(fun) 来查找对应的命名参数。

  • keep_unused – 可选布尔值。如果为 False(默认值),JAX 确定在 fun 中未使用的参数可能会从最终编译的 XLA 可执行文件中剔除。此类参数既不会传输到设备,也不会提供给底层的可执行文件。如果为 True,则不会修剪未使用的参数。

  • device – 这是一个实验性功能,API 可能会发生变化。可选参数,指定 JIT 编译函数运行的设备。(可通过 jax.devices() 获取可用设备。)默认值继承自 XLA 的 DeviceAssignment 逻辑,通常使用 jax.devices()[0]

  • backend – 这是一个实验性功能,API 可能会发生变化。可选参数,一个代表 XLA 后端的字符串:'cpu''gpu''tpu'

  • inline – 可选布尔值。指定此函数是否应内联到封闭的 jaxpr 中。默认为 False。

返回:

fun 的包装版本,已设置为 JIT 编译。

示例

在以下示例中,selu 可以被 XLA 编译成单个融合内核(fused kernel)

>>> import jax
>>>
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.key(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))  
[-0.54485  0.27744 -0.29255 -0.91421 -0.62452 -0.24748
-0.85743 -0.78232  0.76827  0.59566 ]

从 JAX v0.8.1 开始,jit() 支持装饰器工厂模式,用于指定可选关键字

>>> @jax.jit(static_argnames=['n'])
... def g(x, n):
...   for i in range(n):
...     x = x ** 2
...   return x
>>>
>>> g(jnp.arange(4), 3)
Array([   0,    1,  256, 6561], dtype=int32)

为了与旧版本的 JAX 兼容,一种常见的模式是使用 functools.partial()

>>> from functools import partial
>>>
>>> @partial(jax.jit, static_argnames=['n'])
... def g(x, n):
...   for i in range(n):
...     x = x ** 2
...   return x
>>>
>>> g(jnp.arange(4), 3)
Array([   0,    1,  256, 6561], dtype=int32)