jax.devices#

jax.devices(backend=None)[源代码]#

返回给定后端的所有设备列表。

每个设备都由 Device 的子类表示(例如,CpuDeviceGpuDevice)。返回列表的长度等于 device_count(backend)。可以通过比较 Device.process_indexjax.process_index() 返回的值来标识本地设备。

如果 backendNone,则返回默认后端的所有设备。默认后端通常是 'gpu''tpu'(如果可用),否则是 'cpu'

参数:

backend (str | xla_client.Client | None) – 这是一个实验性功能,API 可能会发生更改。可选参数,一个字符串,表示 xla 后端:'cpu''gpu''tpu'

返回:

Device 子类列表。

返回类型:

list[xla_client.Device]