jax.local_devices

jax.local_devices(process_index=None, backend=None, host_id=None)

Like jax.devices(), but only returns devices local to a given process.

If process_index is None, returns devices local to this process.
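As a minimal sketch of the default behaviour (no arguments, so the default backend and the current process are assumed), the following compares the local device list with the global one. The variable names are illustrative only.

```python
import jax

# Devices attached to the current process, on the default backend.
local = jax.local_devices()

# All devices across every process, on the default backend.
everything = jax.devices()

# Each local device should report the index of the current process.
this_process = jax.process_index()
owned_by_this_process = all(d.process_index == this_process for d in local)

print(f"process {this_process}: {len(local)} local of {len(everything)} total devices")
```

In a single-process run the two lists are expected to have the same length; in a multi-process run the local list is a subset of the global one.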

Parameters:
  • process_index (Optional[int]) – the integer index of the process. Valid process indices range from 0 to jax.process_count() - 1.

  • backend (Union[str, Client, None]) – Optional. A string naming the XLA backend: 'cpu', 'gpu', or 'tpu'. This is an experimental feature and the API is likely to change.

  • host_id (Optional[int]) – deprecated alias for process_index; use process_index instead.

Return type:

list[Device]

Returns:

A list of Device objects (instances of backend-specific Device subclasses) local to the given process.
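The parameters above can be combined as in the following sketch, which groups devices by owning process. It assumes the 'cpu' backend, chosen only because it is normally available on any installation; `devices_by_process` is an illustrative name, not part of the API.

```python
import jax

backend = "cpu"  # assumption: any available backend name could be used here

# Map each valid process index to the devices local to that process.
devices_by_process = {
    p: jax.local_devices(process_index=p, backend=backend)
    for p in range(jax.process_count(backend))
}

for p, devs in devices_by_process.items():
    print(f"process {p}: device ids {[d.id for d in devs]}")
```

Iterating over range(jax.process_count(backend)) keeps every process_index within the valid range for that backend.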