jax.device_count#
- jax.device_count(backend=None)[源代码]#
返回设备总数。
在大多数平台上,这与
jax.local_device_count()
相同。但是,在不同设备与不同进程关联的多进程平台上,这将返回所有进程的设备总数。- 参数:
backend (str | xla_client.Client | None | None) – 这是一个实验性功能,API 可能会发生变化。可选,表示 xla 后端的字符串:
'cpu'
、'gpu'
或'tpu'
。- 返回:
设备数量。
- 返回类型: