jax.experimental package

jax.experimental.optix has been moved into its own Python package (optax).

jax.experimental.enable_x64(new_val=True)

Experimental context manager to temporarily enable X64 mode (64-bit types such as float64 and int64). The previous setting is restored when the block exits.

Usage:

>>> import jax.numpy as jnp
>>> from jax.experimental import enable_x64
>>> with enable_x64():
...   print(jnp.arange(10.0).dtype)
...
float64

See also

jax.experimental.disable_x64

temporarily disable X64 mode.

Parameters

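new_val (bool) – the value X64 mode takes inside the context. Defaults to True; passing False disables X64 mode for the duration of the block.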

jax.experimental.disable_x64()

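Experimental context manager to temporarily disable X64 mode, so that arrays default to 32-bit types such as float32. The previous setting is restored when the block exits.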

Usage:

>>> import jax.numpy as jnp
>>> from jax.experimental import disable_x64
>>> with disable_x64():
...   print(jnp.arange(10.0).dtype)
...
float32

See also

jax.experimental.enable_x64

temporarily enable X64 mode.