jax.experimental.enable_x64
- jax.experimental.enable_x64(new_val=True)
Experimental context manager to temporarily enable X64 mode.
Usage:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jax.experimental import enable_x64
>>> x = np.arange(5, dtype='float64')
>>> with enable_x64():
...     print(jnp.asarray(x).dtype)
...
float64
See also
jax.experimental.disable_x64
temporarily disable X64 mode.
- Parameters
new_val (bool)
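The following is a minimal sketch of how the temporary scope behaves, assuming a JAX version that still exposes `jax.experimental.enable_x64` and that passing `new_val=False` disables X64 mode inside the block. The variable names (`dtype_before`, `dtype_inside`, `dtype_disabled`, `dtype_after`) are illustrative only and not part of the API.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import enable_x64

x = np.arange(5, dtype="float64")

# Outside the context manager the result depends on the global
# jax_enable_x64 setting (float32 unless X64 was enabled globally).
dtype_before = jnp.asarray(x).dtype

# Inside the block X64 mode is on, so the float64 input keeps its dtype.
with enable_x64():
    dtype_inside = jnp.asarray(x).dtype

# Assumption: new_val=False turns X64 mode off for the duration of the block.
with enable_x64(False):
    dtype_disabled = jnp.asarray(x).dtype

# After leaving the blocks the previous setting is restored.
dtype_after = jnp.asarray(x).dtype

print(dtype_before, dtype_inside, dtype_disabled, dtype_after)
```

Because the setting only applies inside the `with` block, arrays created before or after it follow the global configuration, so mixing the two can produce arrays of different precision in the same program.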