XLA 编译器标志#
引言#
本指南简要介绍了 XLA 以及 XLA 与 JAX 的关系。有关详细信息,请参阅 XLA 文档。
XLA:JAX 背后的强大引擎#
XLA(Accelerated Linear Algebra,加速线性代数)是专用于线性代数的编译器,在 JAX 的性能和灵活性方面发挥着关键作用。通过转换和编译您的 Python/NumPy 式代码为高效的机器指令,它使 JAX 能够为各种硬件后端(CPU、GPU、TPU)生成优化代码。
JAX 利用 XLA 的 JIT 编译功能,在运行时将您的 Python 函数转换为优化的 XLA 计算。
在 JAX 中配置 XLA:#
您可以通过在运行 Python 脚本或 colab notebook 之前设置 XLA_FLAGS 环境变量来影响 JAX 中 XLA 的行为。
对于 Colab Notebook
使用 os.environ['XLA_FLAGS'] 提供标志
import os
# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'
对于 Python 脚本
在 CLI 命令中将 XLA_FLAGS 指定为一部分
XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py
重要说明
在导入 JAX 或其他相关库之前设置
XLA_FLAGS。在后端初始化后更改XLA_FLAGS将无效,并且由于后端初始化时间定义不明确,因此在执行任何 JAX 代码之前设置XLA_FLAGS通常更安全。尝试不同的标志以针对特定用例优化性能。
更多信息
有关 XLA 的完整且最新的文档可在官方 XLA 文档中找到。
对于 XLA 开源版本支持的后端(CPU、GPU),XLA 标志及其默认值在 xla/debug_options_flags.cc 中定义,完整的标志列表可在 此处找到。
有关如何使用关键 XLA 标志的指南可在 此处找到。
附加阅读