XLA 编译器标志#

简介#

本指南简要概述了 XLA 以及 XLA 与 Jax 之间的关系。有关详细信息,请参阅 XLA 文档

XLA:Jax 背后的动力引擎#

XLA (Accelerated Linear Algebra) 是一种专门用于线性代数的领域特定编译器,在 Jax 的性能和灵活性方面发挥着核心作用。它通过将您的 Python/NumPy 风格代码转换并编译为高效的机器指令,使 Jax 能够为各种硬件后端(CPU、GPU、TPU)生成优化后的代码。

Jax 利用 XLA 的 JIT(即时)编译功能,在运行时将您的 Python 函数转换为优化后的 XLA 计算。

在 Jax 中配置 XLA:#

您可以通过在运行 Python 脚本或 Colab 笔记本之前设置 XLA_FLAGS 环境变量,来影响 XLA 在 Jax 中的行为。

对于 Colab 笔记本

使用 os.environ['XLA_FLAGS'] 提供标志

import os

# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'

对于 Python 脚本

XLA_FLAGS 指定为 CLI 命令的一部分

XLA_FLAGS='--flag1=value1 --flag2=value2'  python3 source.py

重要注意事项

  • 请在导入 Jax 或其他相关库之前设置 XLA_FLAGS。在后端初始化之后更改 XLA_FLAGS 将不会产生任何效果,且由于后端初始化时间并未明确定义,因此在执行任何 Jax 代码之前设置 XLA_FLAGS 通常更为安全。

  • 尝试使用不同的标志,以便为您的特定用例优化性能。

更多信息

  • 有关 XLA 的完整且最新的文档,请查阅官方 XLA 文档

  • 对于 XLA 开源版本支持的后端(CPU、GPU),XLA 标志及其默认值定义在 xla/debug_options_flags.cc 中,标志的完整列表可在此处找到:here

  • 有关如何使用关键 XLA 标志的指南,请见此处

延伸阅读