jax.local_devices

jax.local_devices(process_index=None, backend=None, host_id=None)

Like jax.devices(), but only returns devices local to a given process.

If process_index is None, returns devices local to this process.
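The following is a minimal sketch that contrasts the two calls. It assumes a working JAX installation and uses only the default backend. On a single-process machine both lists are the same. In a multi-process job, `jax.local_devices()` returns only the subset owned by the calling process.

```python
import jax

# Devices attached to the calling process (process_index=None means "this process").
local = jax.local_devices()
# Devices across every process participating in the computation.
everything = jax.devices()

# Each local device reports the index of the process that owns it.
assert all(d.process_index == jax.process_index() for d in local)

# Local devices form a subset of the global device list (compared by device id).
assert {d.id for d in local} <= {d.id for d in everything}

print(f"process {jax.process_index()}: {len(local)} local of {len(everything)} total devices")
```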

Parameters:
  • process_index (int | None) – the integer index of the process. Valid process indices are range(jax.process_count()).

  • backend (str | xla_client.Client | None) – This is an experimental feature and the API is likely to change. Optional; a string naming the XLA backend: 'cpu', 'gpu', or 'tpu'.

  • host_id (int | None) – deprecated alias for process_index; use process_index instead.
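The parameters can be illustrated with a short sketch that queries each process's devices by index and restricts a query to one backend by name. It assumes the CPU backend is available, which is normally the case in a standard JAX installation. In a single-process run, `jax.process_count()` is 1, so the loop visits only index 0.

```python
import jax

# Valid process indices are 0, 1, ..., jax.process_count() - 1.
per_process = {
    i: jax.local_devices(process_index=i)
    for i in range(jax.process_count())
}

# Every device on the default backend belongs to exactly one process.
total = sum(len(devs) for devs in per_process.values())
assert total == len(jax.devices())

# Restrict the query to a single backend by passing its name as a string.
cpu_local = jax.local_devices(backend="cpu")
assert all(d.platform == "cpu" for d in cpu_local)

for i, devs in per_process.items():
    print(i, [str(d) for d in devs])
```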

Return type:

list[xla_client.Device]

Returns:

List of Device subclasses.
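The returned Device objects can be passed directly to placement APIs. The sketch below uses `jax.device_put` to commit an array to the first local device and then runs a computation there. It assumes only that at least one local device exists, and picking index 0 is an arbitrary illustrative choice.

```python
import jax
import jax.numpy as jnp

# Pick the first device attached to this process.
dev = jax.local_devices()[0]

# Commit an array to that device; subsequent operations run on it.
x = jax.device_put(jnp.arange(4.0), dev)
total = (x * 2).sum()  # (0 + 1 + 2 + 3) * 2

print(dev.platform, dev.id, float(total))
```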