jax.device_count#

jax.device_count(backend=None)[源代码]#

返回设备总数。

在大多数平台上,这与 jax.local_device_count() 相同。但是,在多进程平台上,不同设备可能与不同进程关联,在这种情况下,此函数将返回所有进程中的设备总数。

参数:

backend (str | xla_client.Client | None) – 这是一个实验性功能,其 API 可能会更改。可选,一个字符串,表示 XLA 后端:'cpu''gpu''tpu'

返回:

设备数量。

返回类型:

int