jax.local_devices
- jax.local_devices(process_index=None, backend=None, host_id=None)
Like jax.devices(), but only returns devices local to a given process.

If process_index is None, returns devices local to this process.

- Parameters:
process_index (Optional[int]) – the integer index of the process. Valid process indices run from 0 to jax.process_count() - 1.
backend (Union[str, Client, None]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
- Return type:
List[Device]
- Returns:
List of Device subclasses.
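
The following is a minimal usage sketch, not an excerpt from the JAX sources. It assumes a single-process run with the CPU backend available; the variable names (local, everything, same, cpu_local) are illustrative only. It compares the result with jax.devices(), passes the current process index explicitly, and restricts the query to one backend.

```python
import jax

# Devices attached to the current process (process_index=None).
local = jax.local_devices()

# Every device visible to the job, across all processes.
everything = jax.devices()

# Each local device reports the index of the process it belongs to.
this_process = jax.process_index()
assert all(d.process_index == this_process for d in local)

# The local devices are a subset of the global device list.
assert {d.id for d in local} <= {d.id for d in everything}

# Naming the current process explicitly selects the same devices.
same = jax.local_devices(process_index=this_process)

# Restrict the query to the CPU backend (the backend argument is
# described as experimental in the documentation above).
cpu_local = jax.local_devices(backend="cpu")

print(len(local), len(cpu_local))
```

In a multi-process job, passing another valid index from range(jax.process_count()) would instead return the devices attached to that process.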