jax.devices
jax.devices(backend=None)
Returns a list of all devices for a given backend.
Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(backend). Local devices can be identified by comparing Device.process_index to the value returned by jax.process_index().

If backend is None, returns all the devices from the default backend. The default backend is generally 'gpu' or 'tpu' if available, otherwise 'cpu'.
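The properties described above can be checked directly. The following is a minimal sketch, assuming a recent JAX release. Besides jax.devices, jax.device_count and jax.process_index, it uses jax.local_devices, jax.default_backend and the Device.platform attribute, which this section does not mention; treat their use here as illustrative.

```python
import jax

# All devices on the default backend (GPU or TPU if available, otherwise CPU).
devices = jax.devices()

# The list length matches the device count reported for the same backend.
assert len(devices) == jax.device_count()

# Local devices: those whose process_index equals the current process's index.
local = [d for d in devices if d.process_index == jax.process_index()]

# jax.local_devices() applies the same filter to the default backend.
assert set(local) == set(jax.local_devices())

# Passing a backend name selects that backend explicitly; CPU is always present.
cpu_devices = jax.devices("cpu")
assert all(d.platform == "cpu" for d in cpu_devices)

print(jax.default_backend(), [str(d) for d in devices])
```

In a single-process program every device is local, so local and devices contain the same elements. The filter only becomes meaningful in a multi-process run, where jax.devices() also lists devices owned by other processes.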