jax.local_device_count

jax.local_device_count(backend=None)

Returns the number of devices addressable by this process.

Parameters:

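backend (str | xla_client.Client | None) – Optional platform name (for example 'cpu', 'gpu', or 'tpu') or XLA client to query. If None, the default backend is used.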

Return type:

int
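
Example:

A minimal sketch of how this function might be used. It assumes a single-process run with JAX installed; the variable names are illustrative only. The count is compared against `jax.local_devices()` and `jax.device_count()`, which report the local device list and the total device count across all processes.

```python
import jax

# Number of devices this process can address on the default backend.
n_local = jax.local_device_count()

# The same devices as a list; its length should match the count above.
local_devices = jax.local_devices()

# Total devices across all processes. In a single-process run this
# equals n_local; in a multi-process run it can be larger.
n_global = jax.device_count()

# Query a specific platform by name. The CPU backend is assumed to be present.
n_cpu = jax.local_device_count(backend="cpu")

print(f"local: {n_local}, global: {n_global}, cpu: {n_cpu}")
```

In a multi-process setup, each process sees only its own addressable devices, so `n_local` is typically what you would use to size a per-process leading batch axis.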