jax.default_device
- jax.default_device = <jax._src.config.State object>
Context manager for the jax_default_device config option.
Configure the default device for JAX operations. Set to a Device object (e.g. jax.devices("cpu")[0]) to use that Device as the default device for JAX operations and jit’d function calls (there is no effect on multi-device computations, e.g. pmapped function calls). Set to None to use the system default device. See Controlling data and computation placement on devices for more information on device placement.
- Parameters:
new_val (Any)
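A minimal usage sketch. It assumes a recent JAX release in which jax.Array.devices() returns a set of devices. Inside the context manager, newly created arrays and jit’d calls are placed on the chosen CPU device. The same option can also be set globally through jax.config.update, and setting it back to None restores the system default device.

```python
import jax
import jax.numpy as jnp

# Pick the first CPU device; a CPU backend is always available.
cpu = jax.devices("cpu")[0]

# Inside the context, new (uncommitted) arrays and jit'd calls default to `cpu`.
with jax.default_device(cpu):
    x = jnp.ones(3)                  # array created on `cpu`
    y = jax.jit(lambda a: a * 2)(x)  # jit'd call also executes on `cpu`

print(x.devices(), y.devices())

# The config option can also be set globally rather than for a scope.
jax.config.update("jax_default_device", cpu)
# None means "use the system default device" again.
jax.config.update("jax_default_device", None)
```

The context-manager form is usually preferable when only part of a program should run on a non-default device, because the setting is scoped and restored automatically on exit.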