XLA 编译器标志#
简介#
本指南简要介绍了 XLA 以及 XLA 与 Jax 的关系。有关深入细节,请参阅XLA 文档。
XLA:Jax 背后的强大引擎#
XLA (Accelerated Linear Algebra) 是一个用于线性代数的领域特定编译器,它在 Jax 的性能和灵活性方面发挥着关键作用。它通过将您的 Python/NumPy 风格代码转换为高效的机器指令,使 Jax 能够为各种硬件后端(CPU、GPU、TPU)生成优化代码。
Jax 使用 XLA 的 JIT 编译功能在运行时将您的 Python 函数转换为优化的 XLA 计算。
在 Jax 中配置 XLA:#
您可以通过在运行 Python 脚本或 Colab 笔记本之前设置 XLA_FLAGS 环境变量来影响 XLA 在 Jax 中的行为。
对于 Colab 笔记本
使用 os.environ['XLA_FLAGS']
提供标志
import os
# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'
对于 Python 脚本
将 XLA_FLAGS
指定为 CLI 命令的一部分
XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py
重要提示
在导入 Jax 或其他相关库之前设置
XLA_FLAGS
。在后端初始化后更改XLA_FLAGS
将无效,并且鉴于后端初始化时间未明确定义,通常更安全地在执行任何 Jax 代码之前设置XLA_FLAGS
。尝试不同的标志以优化您特定用例的性能。
更多信息
关于 XLA 的完整和最新文档可在官方XLA 文档中找到。
对于 XLA 开源版本支持的后端(CPU、GPU),XLA 标志及其默认值定义在xla/debug_options_flags.cc中,并且完整的标志列表可以在此处找到。
关于如何使用关键 XLA 标志的指南可以在此处找到。
延伸阅读