jax.devices

jax.devices(backend=None)

Returns a list of all devices for a given backend.

Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(backend). Local devices can be identified by comparing Device.process_index to the value returned by jax.process_index().
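The relationship between the returned list, device_count(), and process_index can be checked directly. This is a minimal sketch that assumes a working JAX install. It also uses jax.local_device_count(), which is not mentioned above, as a cross-check on the filtered list.

```python
import jax

# All devices for the default backend, across every process.
all_devices = jax.devices()

# The list length equals device_count() for the same (default) backend.
assert len(all_devices) == jax.device_count()

# Devices attached to this process report a matching process_index.
local = [d for d in all_devices if d.process_index == jax.process_index()]

# The filtered list should match the number of devices local to this process.
assert len(local) == jax.local_device_count()

print(f"{len(local)} of {len(all_devices)} devices are local to process {jax.process_index()}")
```

In a single-process program, every device is local, so both counts are equal.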

If backend is None, returns all the devices from the default backend. The default backend is generally 'gpu' or 'tpu' if available, otherwise 'cpu'.
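To see which backend None resolves to, jax.default_backend() (not part of this page) returns its name. The sketch below assumes that calling jax.devices() with no argument and calling it with the default backend's name select the same backend. Under that assumption, both calls should return the same devices.

```python
import jax

# Name of the backend used when backend=None, e.g. 'cpu', 'gpu', or 'tpu'.
backend = jax.default_backend()

# With no argument, jax.devices() uses the default backend.
default_devices = jax.devices()

# Passing the default backend's name explicitly should give the same list.
assert default_devices == jax.devices(backend)

print(backend, default_devices)
```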

Parameters:

backend (str | xla_client.Client | None) – Optional. A string naming the XLA backend: 'cpu', 'gpu', or 'tpu'. This is an experimental feature and the API is likely to change.
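A string backend can be passed to request devices from a backend other than the default. The example below uses 'cpu' because JAX normally makes a CPU backend available even when accelerators are present. It assumes this holds on the machine running it, and it uses jax.device_put, which this page does not cover, to place an array on one of those devices.

```python
import jax
import jax.numpy as jnp

# Explicitly request the CPU backend, regardless of the default backend.
cpu_devices = jax.devices("cpu")

# Every returned device should belong to the CPU platform.
print([d.platform for d in cpu_devices])

# Place a small array on the first CPU device.
x = jax.device_put(jnp.arange(4), cpu_devices[0])
print(x)
```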

Return type:

list[xla_client.Device]

Returns:

List of Device objects (instances of Device subclasses).