jax.default_device

jax.default_device = <jax._src.config._StateContextManager object>

Context manager for the jax_default_device config option.
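Because it is a context manager, jax.default_device is normally used in a with block. The following is a minimal sketch, assuming a JAX version in which jax.Array exposes a devices() method. The CPU device is chosen only because every JAX installation has one.

```python
import jax
import jax.numpy as jnp

# Pick the first CPU device; any Device object would work the same way.
cpu = jax.devices("cpu")[0]

# Inside the block, arrays created without an explicit device land on `cpu`.
with jax.default_device(cpu):
    x = jnp.ones(3)

# jax.Array.devices() returns the set of devices that hold the array.
print(x.devices())
```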

Configures the default device for JAX operations. Set it to a Device object (e.g. jax.devices("cpu")[0]) to use that device as the default for JAX operations and jit’d function calls. This has no effect on multi-device computations such as pmapped function calls. Set it to None to use the system default device. See "Controlling data and computation placement on devices" for more information on device placement.
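The same option can also be set process-wide by passing its name, jax_default_device, to jax.config.update. The sketch below assumes the standard jax.config.update interface. It shows a jit’d call following the configured default, and then passing None to restore the system default device. The function double is a hypothetical helper used only for illustration.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

@jax.jit
def double(v):
    # Hypothetical jit'd function, used only to show where it executes.
    return v * 2

# Set the default device globally instead of inside a `with` block.
jax.config.update("jax_default_device", cpu)
y = double(jnp.arange(3))

# Passing None restores the system default device.
jax.config.update("jax_default_device", None)

print(y.devices(), y)
```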