jax.device_count

jax.device_count(backend=None)

Returns the total number of devices.

On most platforms, this is the same as jax.local_device_count(). However, on multi-process platforms, where different devices are attached to different processes, it returns the total number of devices across all processes, not just those local to the calling process.
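A minimal sketch contrasting the global and per-process counts. It uses only jax.device_count, jax.local_device_count, jax.process_index, and jax.process_count. In an ordinary single-process run the two counts agree; in a multi-process run the global count is the sum of the local counts over all processes.

```python
import jax

# Devices across every process participating in this JAX program.
total = jax.device_count()
# Devices attached to the calling process only.
local = jax.local_device_count()

print(f"process {jax.process_index()} of {jax.process_count()}: "
      f"{local} local / {total} global devices")

# The local count can never exceed the global count.
assert local <= total
```

A common consequence: code that splits work per process (for example, feeding data to the local devices) should use the local count, while arithmetic over the whole job (for example, a global batch size) should use the global count.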

Parameters:

backend (str | xla_client.Client | None) – Optional. A string naming the XLA backend: 'cpu', 'gpu', or 'tpu'. If omitted, the default backend is used. This is an experimental feature and the API is likely to change.
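A sketch of querying specific backends through the backend argument. The helper count_or_zero is hypothetical, not part of JAX, and it assumes that requesting a backend that is not present in the installation raises RuntimeError, which is how current JAX releases behave as far as we know. Because the argument is marked experimental, treat this as illustrative only.

```python
import jax

def count_or_zero(platform: str) -> int:
    """Hypothetical helper: device count for `platform`, or 0 if unavailable."""
    try:
        return jax.device_count(backend=platform)
    except RuntimeError:
        # Assumed: raised when the requested backend is not present.
        return 0

# The CPU backend is normally always available.
cpu_count = jax.device_count(backend="cpu")
# With no argument, the default backend (see jax.default_backend()) is used.
default_count = jax.device_count()

print(f"default backend: {jax.default_backend()} ({default_count} devices)")
for platform in ("cpu", "gpu", "tpu"):
    print(platform, count_or_zero(platform))
```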

Returns:

Number of devices.

Return type:

int