jax.device_count
- jax.device_count(backend=None)
Returns the total number of devices.

On most platforms, this is the same as jax.local_device_count(). However, on multi-process platforms where different devices are associated with different processes, it returns the total number of devices across all processes.

- Parameters:
backend (str | xla_client.Client | None) – Optional. This is an experimental feature and the API is likely to change. Either a string naming the XLA backend ('cpu', 'gpu', or 'tpu') or, as the type annotation allows, a backend client object. If None, the default backend is used.
- Return type:
int
- Returns:
The total number of devices across all processes.
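A minimal usage sketch, assuming a standard JAX installation where the CPU backend is always available. It contrasts the global count with the per-process count and passes backend explicitly. jax.process_index() and jax.process_count() are not described on this page; they are related JAX process-query functions used here only to label the output.

```python
import jax

# Devices visible across *all* processes vs. devices attached to this process.
total = jax.device_count()
local = jax.local_device_count()
print(f"process {jax.process_index()} of {jax.process_count()}: "
      f"{local} local / {total} global devices")

# In a single-process run the two counts agree.
if jax.process_count() == 1:
    assert total == local

# Restrict the count to a single backend by name ('cpu' is assumed available).
cpu_count = jax.device_count("cpu")
print(f"{cpu_count} CPU device(s)")
```

In a multi-process run where every process owns the same number of devices, total works out to local * jax.process_count(). That is the common layout, but this page does not guarantee it, so code that shards data should query both counts rather than derive one from the other.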