XLA 编译器标志#

引言#

本指南简要介绍了 XLA 以及 XLA 与 JAX 的关系。有关详细信息,请参阅 XLA 文档

XLA:JAX 背后的强大引擎#

XLA(Accelerated Linear Algebra,加速线性代数)是专用于线性代数的编译器,在 JAX 的性能和灵活性方面发挥着关键作用。通过转换和编译您的 Python/NumPy 式代码为高效的机器指令,它使 JAX 能够为各种硬件后端(CPU、GPU、TPU)生成优化代码。

JAX 利用 XLA 的 JIT 编译功能,在运行时将您的 Python 函数转换为优化的 XLA 计算。

在 JAX 中配置 XLA:#

您可以通过在运行 Python 脚本或 colab notebook 之前设置 XLA_FLAGS 环境变量来影响 JAX 中 XLA 的行为。

对于 Colab Notebook

使用 os.environ['XLA_FLAGS'] 提供标志

import os

# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'

对于 Python 脚本

在 CLI 命令中将 XLA_FLAGS 指定为一部分

XLA_FLAGS='--flag1=value1 --flag2=value2'  python3 source.py

重要说明

  • 在导入 JAX 或其他相关库之前设置 XLA_FLAGS。在后端初始化后更改 XLA_FLAGS 将无效,并且由于后端初始化时间定义不明确,因此在执行任何 JAX 代码之前设置 XLA_FLAGS 通常更安全。

  • 尝试不同的标志以针对特定用例优化性能。

更多信息

  • 有关 XLA 的完整且最新的文档可在官方 XLA 文档中找到。

  • 对于 XLA 开源版本支持的后端(CPU、GPU),XLA 标志及其默认值在 xla/debug_options_flags.cc 中定义,完整的标志列表可在 此处找到。

  • 有关如何使用关键 XLA 标志的指南可在 此处找到。

附加阅读