jax.devices#

jax.devices(backend=None)[source]#

返回给定后端的所有设备列表。

每个设备都由 Device (例如 CpuDevice, GpuDevice) 的子类表示。返回列表的长度等于 device_count(backend)。可以通过比较 Device.process_indexjax.process_index() 返回的值来识别本地设备。

如果 backendNone,则返回来自默认后端的所有设备。默认后端通常是 'gpu''tpu'(如果可用),否则为 'cpu'

参数:

backend (str | xla_client.Client | None | None) – 这是一个实验性功能,API 可能会更改。 可选,表示 xla 后端的字符串:'cpu', 'gpu', 或 'tpu'

返回:

Device 子类列表。

返回类型:

list[xla_client.Device]