jax.default_matmul_precision


jax.default_matmul_precision = <jax._src.config._StateContextManager object>

Context manager for the jax_default_matmul_precision config option.
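A minimal sketch of the two usual ways to set this option, assuming a recent JAX release: process-wide through jax.config.update, or for a scoped region with this context manager. The array shapes and the chosen levels are arbitrary illustrations.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((128, 128), dtype=jnp.float32)
y = jnp.ones((128, 128), dtype=jnp.float32)

# Process-wide default, set through the config option this context manager wraps.
jax.config.update("jax_default_matmul_precision", "float32")

# Scoped override: only computations run inside the block use "bfloat16".
with jax.default_matmul_precision("bfloat16"):
    z_scoped = jnp.matmul(x, y)

# Outside the block, the process-wide setting ("float32") applies again.
z_global = jnp.matmul(x, y)
```

Every entry of ones @ ones equals 128, which is exactly representable in bfloat16, so both results agree here; the scoped form is convenient when only one region of a program should trade accuracy for speed.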

Controls the default precision of matrix multiplications and convolutions on 32-bit inputs.

Some platforms, like TPU, offer configurable precision levels for matrix multiplication and convolution computations, trading off accuracy for speed. The precision can be controlled for each operation; for example, see the jax.lax.conv_general_dilated() and jax.lax.dot() docstrings. However, it can be useful to control the default behavior for operations that are not given an explicit precision.
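The sketch below contrasts the default with an explicit per-operation precision. It assumes that a precision passed directly to jax.lax.dot() or jax.lax.conv_general_dilated() takes priority over the configured default, which only fills in when precision is left unset. The shapes, the PRNG seed, and the NCHW/OIHW layouts (the defaults when dimension_numbers is omitted) are illustrative choices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
a = jax.random.normal(k1, (64, 32), dtype=jnp.float32)
b = jax.random.normal(k2, (32, 16), dtype=jnp.float32)
images = jax.random.normal(k3, (1, 3, 8, 8), dtype=jnp.float32)   # NCHW
kernels = jax.random.normal(k4, (4, 3, 3, 3), dtype=jnp.float32)  # OIHW

with jax.default_matmul_precision("bfloat16"):
    # No precision argument: this dot uses the default set above.
    c_default = lax.dot(a, b)
    # An explicit precision overrides the default for this call only.
    c_highest = lax.dot(a, b, precision=lax.Precision.HIGHEST)
    # Convolutions follow the same rule; here the default applies.
    out = lax.conv_general_dilated(
        images, kernels, window_strides=(1, 1), padding="SAME"
    )

# Float64 reference for checking the HIGHEST-precision result.
reference = np.asarray(a, np.float64) @ np.asarray(b, np.float64)
```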

This option can be used to control the default precision level for computations involved in matrix multiplication and convolution on 32-bit inputs. The levels roughly describe the precision at which scalar products are computed. The ‘bfloat16’ option is the fastest and least precise; ‘float32’ is close to full float32 precision; ‘tensorfloat32’ is intermediate.
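One way to observe the trade-off is to measure each level's error against a float64 reference. This is a sketch under stated assumptions: the random 256 x 256 inputs are arbitrary, and on backends that do not expose configurable matmul precision (CPU, for instance, in many versions) all three levels may report the same error.

```python
import numpy as np
import jax
import jax.numpy as jnp

rng = np.random.default_rng(0)
a = rng.standard_normal((256, 256)).astype(np.float32)
b = rng.standard_normal((256, 256)).astype(np.float32)
# High-precision reference computed on the host in float64.
reference = a.astype(np.float64) @ b.astype(np.float64)

errors = {}
for level in ("bfloat16", "tensorfloat32", "float32"):
    with jax.default_matmul_precision(level):
        result = jnp.matmul(a, b)
    # Largest absolute deviation from the float64 reference.
    errors[level] = float(np.max(np.abs(np.asarray(result) - reference)))

for level, err in errors.items():
    print(f"{level:>14}: max abs error {err:.2e}")
```

On hardware that honors the setting, the errors would be expected to shrink from ‘bfloat16’ to ‘tensorfloat32’ to ‘float32’.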