Returns a list of all devices for a given backend.

Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(backend). Local devices can be identified by comparing Device.process_index to the value returned by jax.process_index().
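The following is a minimal sketch of the relationships described above. It assumes a standard JAX installation and also uses `jax.device_count()` and `jax.local_device_count()`, which are not documented in this section. In a single-process run every device is local, so both lists coincide.

```python
import jax

# All devices of the default backend, across every process.
all_devices = jax.devices()

# Per the description above, the length matches device_count() for the same backend.
assert len(all_devices) == jax.device_count()

# Keep only the devices attached to the current process.
my_index = jax.process_index()
local_devices = [d for d in all_devices if d.process_index == my_index]

for d in local_devices:
    print(d.id, d.platform, d.process_index)
```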

If backend is None, returns all the devices from the default backend. The default backend is generally 'gpu' or 'tpu' if available, otherwise 'cpu'.
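To check which backend a plain `jax.devices()` call resolves to on a given machine, one can compare the result against `jax.default_backend()`. That function is not described in this section, and this is a sketch assuming it returns the default platform name as a string:

```python
import jax

backend_name = jax.default_backend()  # e.g. 'cpu', 'gpu', or 'tpu'
default_devices = jax.devices()       # equivalent to jax.devices(None)

# Devices returned without an explicit backend should all report the default platform.
print(backend_name, [d.platform for d in default_devices])
```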

Parameters

backend (Union[str, Client, None]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
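Passing a backend string that is not present on the machine raises an error. In current JAX releases this is a `RuntimeError`, although that is an observation and not something this docstring specifies. The sketch below shows how code aimed at an accelerator might fall back to the CPU backend. The helper `devices_or_cpu` is hypothetical and not part of JAX, and the sketch assumes the CPU backend is enabled, which is JAX's default.

```python
import jax

def devices_or_cpu(platform: str) -> list:
    """Hypothetical helper: devices for `platform`, or CPU devices if it is unavailable."""
    try:
        return jax.devices(platform)
    except RuntimeError:
        # Assumption: an unavailable backend surfaces as RuntimeError.
        return jax.devices("cpu")

devs = devices_or_cpu("gpu")
print(devs[0].platform, len(devs))
```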

Return type

List of Device subclasses.