jax.lax.platform_dependent#

jax.lax.platform_dependent(*args, default=None, **per_platform)[source]#

支持平台特定代码。

在 JAX 中,计算实际运行的平台确定得非常晚,例如,根据数据所在位置确定。当使用 AOT 降低或序列化时,计算可能在不同的机器上编译和执行,甚至在降低时不可用的平台上。这意味着使用 Python 条件语句编写平台相关代码是不安全的,例如,基于当前默认的 JAX 平台。相反,可以使用platform_dependent

用法

def cpu_code(*args): ...
def tpu_code(*args): ...
def other_platforms_code(*args): ...
res = platform_dependent(*args, cpu=cpu_code, tpu=tpu_code,
                         default=other_platforms_code)

当阶段化代码在 CPU 上执行时,这等同于cpu_code(*args);在 TPU 上等同于tpu_code(*args);在任何其他平台上等同于other_platforms_code(*args)。与 Python 条件语句不同,所有替代方案都会被追踪并阶段化到 Jaxpr。这与switch()类似,并且是基于它实现的,它从该函数继承了在转换下的行为。

switch()不同,执行哪个分支的选择会更早做出:在大多数情况下,是在降低(lowering)过程中,当降低平台已知时;在多平台降低和序列化的罕见情况下,StableHLO 代码将包含一个基于实际平台的条件语句。该条件在编译之前即时解析,当编译平台已知时。这意味着编译器实际上永远不会看到条件语句。

参数:
  • *args (Any) – 传递给每个分支的 JAX 数组。可以是 PyTrees。

  • **per_platform (Callable[..., _T]) – 用于不同平台的计算分支。这些分支是使用*args调用的 JAX 可调用对象。关键字是平台名称,例如“cpu”、“tpu”、“cuda”、“rocm”。

  • default (Callable[..., _T] | None) – (可选)当平台未在per_platform中提及时使用的默认分支。如果没有default,则当代码在per_platform中未提及的平台降低时,将出现错误。

返回:

返回值per_platform[execution_platform](*args)