jax.devices
- jax.devices(backend=None)
Returns a list of all devices for a given backend.
Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(backend). Local devices can be identified by comparing Device.process_index to the value returned by jax.process_index().
If backend is None, returns all the devices from the default backend. The default backend is generally 'gpu' or 'tpu' if available, otherwise 'cpu'.
- Parameters:
backend (str | xla_client.Client | None) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
- Return type:
list[xla_client.Device]
- Returns:
List of devices, each an instance of a Device subclass.
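A minimal usage sketch, assuming a standard single-process JAX installation (the CPU backend is always available). It lists the default-backend devices, checks the length against jax.device_count(), filters local devices by process_index, and requests the CPU backend explicitly. The Device.platform attribute and jax.local_device_count() are used for illustration and are not described in the entry above.

```python
import jax

# Devices from the default backend ('gpu' or 'tpu' if available, otherwise 'cpu').
devices = jax.devices()

# The returned list has the same length as device_count() for that backend.
print(len(devices) == jax.device_count())

# Local devices: those whose process_index matches this process's index.
local = [d for d in devices if d.process_index == jax.process_index()]
print(f"{len(devices)} total device(s), {len(local)} local")

# Explicitly request the CPU backend (the backend argument is experimental).
cpu_devices = jax.devices("cpu")
print([d.platform for d in cpu_devices])
```

In a single-process run every device is local, so the filtered list should contain all of them; the process_index comparison matters only in multi-process setups.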