jax.device_count
- jax.device_count(backend=None)
Returns the total number of devices.
On most platforms, this is the same as jax.local_device_count(). However, on multi-process platforms where different devices are associated with different processes, this will return the total number of devices across all processes.
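The relationship between the global and per-process counts can be checked directly. The snippet below is a minimal sketch, assuming JAX is installed and the default backend is used; the variable names `total` and `local` are illustrative only. It relies on `jax.process_count()` to tell whether the run is single-process.

```python
import jax

# Devices visible across every process participating in the job.
total = jax.device_count()

# Devices attached to the current process only.
local = jax.local_device_count()

print(f"total devices: {total}, local devices: {local}")

# In a single-process run the two counts coincide; in a multi-process
# run the global count covers all processes, so it is at least as
# large as the local count.
if jax.process_count() == 1:
    assert total == local
else:
    assert total >= local
```

In a multi-process setup, each process would run the same script and could report a different `local` value while seeing the same `total`.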