XLA 编译器标志#
简介#
本指南简要概述了 XLA 以及 XLA 与 Jax 之间的关系。有关详细信息,请参阅 XLA 文档。
XLA:Jax 背后的动力引擎#
XLA (Accelerated Linear Algebra) 是一种专门用于线性代数的领域特定编译器,在 Jax 的性能和灵活性方面发挥着核心作用。它通过将您的 Python/NumPy 风格代码转换并编译为高效的机器指令,使 Jax 能够为各种硬件后端(CPU、GPU、TPU)生成优化后的代码。
Jax 利用 XLA 的 JIT(即时)编译功能,在运行时将您的 Python 函数转换为优化后的 XLA 计算。
在 Jax 中配置 XLA:#
您可以通过在运行 Python 脚本或 Colab 笔记本之前设置 XLA_FLAGS 环境变量,来影响 XLA 在 Jax 中的行为。
对于 Colab 笔记本
使用 os.environ['XLA_FLAGS'] 提供标志
import os
# Set multiple flags separated by spaces
os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'
对于 Python 脚本
将 XLA_FLAGS 指定为 CLI 命令的一部分
XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py
重要注意事项
请在导入 Jax 或其他相关库之前设置
XLA_FLAGS。在后端初始化之后更改XLA_FLAGS将不会产生任何效果,且由于后端初始化时间并未明确定义,因此在执行任何 Jax 代码之前设置XLA_FLAGS通常更为安全。尝试使用不同的标志,以便为您的特定用例优化性能。
更多信息
有关 XLA 的完整且最新的文档,请查阅官方 XLA 文档。
对于 XLA 开源版本支持的后端(CPU、GPU),XLA 标志及其默认值定义在 xla/debug_options_flags.cc 中,标志的完整列表可在此处找到:here。
有关如何使用关键 XLA 标志的指南,请见此处。
延伸阅读