jax.process_count

jax.process_count(backend=None)

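Returns the number of JAX processes associated with the backend. In an ordinary single-process program this is 1. In a multi-process (multi-host) setup, such as one started with jax.distributed.initialize, it is the total number of processes participating in the computation.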
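A common use of the process count is splitting a global input batch so that each process loads only its own slice, with jax.process_index() selecting which slice. The sketch below shows that pattern. The helper name local_shard is hypothetical, and it assumes the leading dimension of the batch divides evenly by the process count.

```python
import jax
import numpy as np


def local_shard(global_batch: np.ndarray) -> np.ndarray:
    """Return the contiguous slice of global_batch owned by this process.

    Hypothetical helper: assumes the leading dimension is divisible by
    jax.process_count().
    """
    n = jax.process_count()   # total number of JAX processes
    i = jax.process_index()   # index of this process, in [0, n)
    if global_batch.shape[0] % n != 0:
        raise ValueError("leading dimension must be divisible by process_count()")
    per_process = global_batch.shape[0] // n
    return global_batch[i * per_process:(i + 1) * per_process]


batch = np.arange(16).reshape(8, 2)
shard = local_shard(batch)
print(jax.process_count(), shard.shape)
```

In a single-process run the process count is 1, so the "shard" is the whole batch.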

Parameters:

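backend (str | xla_client.Client | None) – the backend to query, given as a platform name such as "cpu", "gpu", or "tpu", or as a client object. If None, the default backend is used.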

Return type:

int
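The sketch below queries the default backend and an explicitly named one, and then cross-checks the result against the process_index attribute that each device reports. It assumes the "cpu" backend is available, which is the case in a standard JAX install. The cross-check is only an illustration: it counts distinct owning processes among the visible devices, which typically matches jax.process_count() but is not guaranteed to in every configuration.

```python
import jax

# Default backend (backend=None).
n_default = jax.process_count()

# Explicitly named backend; "cpu" is assumed to be available.
n_cpu = jax.process_count(backend="cpu")

# Each device records the index of the process that owns it.
owners = {d.process_index for d in jax.devices()}

print(n_default, n_cpu, len(owners))
```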