jax.lib.xla_bridge.get_compile_options#

jax.lib.xla_bridge.get_compile_options(num_replicas, num_partitions, device_assignment=None, use_spmd_partitioning=True, use_auto_spmd_partitioning=False, auto_spmd_partitioning_mesh_shape=None, auto_spmd_partitioning_mesh_ids=None, env_options_overrides=None, fdo_profile=None, detailed_logging=True, backend=None)[源代码]#

返回要使用的编译选项,该选项从标志值派生。

参数:
  • num_replicas (int) – 要编译的副本数。

  • num_partitions (int) – 要编译的分区数。

  • device_assignment – jax 设备的ndarray 可选,指示逻辑副本到物理设备的分配(默认从 xla_client.CompileOptions 继承)。必须与 num_replicasnum_partitions 一致。

  • use_spmd_partitioning (bool) – 布尔值,指示是否在 XLA 中启用 SPMD 或 MPMD 分区。

  • use_auto_spmd_partitioning (bool) – 布尔值,指示是否为 SPMD 分区程序自动生成 XLA 分片。

  • auto_spmd_partitioning_mesh_shape (list[int] | None) – 用于创建 auto_spmd_partitioning 搜索空间的设备网格形状。

  • auto_spmd_partitioning_mesh_ids (list[int] | None) – 用于创建 auto_spmd_partitioning 搜索空间的设备 ID。

  • env_options_overrides (dict[str, str] | None) – 编译器分析的其他选项的字典

  • fdo_profile (bytes | None) – 用于反馈定向优化的可选配置文件,传递给 XLA。

  • detailed_logging (bool) – 这是否是关于 XLA 应该记录编译信息的“有趣的”计算?

  • backend (xc.Client | None) – 客户端(如果可用)。

返回类型:

xc.CompileOptions