XLA 编译器标志#

简介#

本指南简要介绍了 XLA 以及 XLA 与 Jax 的关系。有关深入细节,请参阅XLA 文档

XLA:Jax 背后的强大引擎#

XLA (Accelerated Linear Algebra) 是一个用于线性代数的领域特定编译器,它在 Jax 的性能和灵活性方面发挥着关键作用。它通过将您的 Python/NumPy 风格代码转换为高效的机器指令,使 Jax 能够为各种硬件后端(CPU、GPU、TPU)生成优化代码。

Jax 使用 XLA 的 JIT 编译功能在运行时将您的 Python 函数转换为优化的 XLA 计算。

在 Jax 中配置 XLA:#

您可以通过在运行 Python 脚本或 Colab 笔记本之前设置 XLA_FLAGS 环境变量来影响 XLA 在 Jax 中的行为。

对于 Colab 笔记本

使用 os.environ['XLA_FLAGS'] 提供标志

import os

# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'

对于 Python 脚本

XLA_FLAGS 指定为 CLI 命令的一部分

XLA_FLAGS='--flag1=value1 --flag2=value2'  python3 source.py

重要提示

  • 在导入 Jax 或其他相关库之前设置 XLA_FLAGS。在后端初始化后更改 XLA_FLAGS 将无效,并且鉴于后端初始化时间未明确定义,通常更安全地在执行任何 Jax 代码之前设置 XLA_FLAGS

  • 尝试不同的标志以优化您特定用例的性能。

更多信息

  • 关于 XLA 的完整和最新文档可在官方XLA 文档中找到。

  • 对于 XLA 开源版本支持的后端(CPU、GPU),XLA 标志及其默认值定义在xla/debug_options_flags.cc中,并且完整的标志列表可以在此处找到。

  • 关于如何使用关键 XLA 标志的指南可以在此处找到。

延伸阅读