`_
on this lattice, which generates the following binary promotion table:
.. csv-table::
   :header: , b1, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c8, c16, i*, f*, c*

   b1, b1, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c8, c16, i*, f*, c*
   u1, u1, u1, u2, u4, u8, i2, i2, i4, i8, bf, f2, f4, f8, c8, c16, u1, f*, c*
   u2, u2, u2, u2, u4, u8, i4, i4, i4, i8, bf, f2, f4, f8, c8, c16, u2, f*, c*
   u4, u4, u4, u4, u4, u8, i8, i8, i8, i8, bf, f2, f4, f8, c8, c16, u4, f*, c*
   u8, u8, u8, u8, u8, u8, f*, f*, f*, f*, bf, f2, f4, f8, c8, c16, u8, f*, c*
   i1, i1, i2, i4, i8, f*, i1, i2, i4, i8, bf, f2, f4, f8, c8, c16, i1, f*, c*
   i2, i2, i2, i4, i8, f*, i2, i2, i4, i8, bf, f2, f4, f8, c8, c16, i2, f*, c*
   i4, i4, i4, i4, i8, f*, i4, i4, i4, i8, bf, f2, f4, f8, c8, c16, i4, f*, c*
   i8, i8, i8, i8, i8, f*, i8, i8, i8, i8, bf, f2, f4, f8, c8, c16, i8, f*, c*
   bf, bf, bf, bf, bf, bf, bf, bf, bf, bf, bf, f4, f4, f8, c8, c16, bf, bf, c8
   f2, f2, f2, f2, f2, f2, f2, f2, f2, f2, f4, f2, f4, f8, c8, c16, f2, f2, c8
   f4, f4, f4, f4, f4, f4, f4, f4, f4, f4, f4, f4, f4, f8, c8, c16, f4, f4, c8
   f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, f8, c16, c16, f8, f8, c16
   c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c16, c8, c16, c8, c8, c8
   c16, c16, c16, c16, c16, c16, c16, c16, c16, c16, c16, c16, c16, c16, c16, c16, c16, c16, c16
   i*, i*, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c8, c16, i*, f*, c*
   f*, f*, f*, f*, f*, f*, f*, f*, f*, f*, bf, f2, f4, f8, c8, c16, f*, f*, c*
   c*, c*, c*, c*, c*, c*, c*, c*, c*, c*, c8, c8, c8, c16, c8, c16, c*, c*, c*

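
Individual cells of this table can be checked programmatically with the public
:func:`jax.numpy.promote_types` and :func:`jax.numpy.result_type` functions; the
snippet below spot-checks a few entries (a sanity check, not an exhaustive test):

```python
import numpy as np
import jax.numpy as jnp

# u1 x i1 -> i2: no 8-bit signed type holds all uint8 values,
# so uint8 and int8 promote to int16.
assert jnp.promote_types(np.uint8, np.int8) == np.dtype('int16')

# bf x f2 -> f4: bfloat16 and float16 promote to float32.
assert jnp.promote_types(jnp.bfloat16, np.float16) == np.dtype('float32')

# f2 x f* -> f2: a weak Python float defers to the concrete float dtype.
assert jnp.result_type(np.float16, float) == np.dtype('float16')
```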
.. The table above was generated with the help of the following Python code
   (adapted here to emit plain comma-separated rows rather than the original HTML):

   import numpy as np
   import jax.numpy as jnp
   from jax._src import dtypes

   types = [np.bool_, np.uint8, np.uint16, np.uint32, np.uint64,
            np.int8, np.int16, np.int32, np.int64,
            jnp.bfloat16, np.float16, np.float32, np.float64,
            np.complex64, np.complex128, int, float, complex]

   def name(d):
     if d == jnp.bfloat16:
       return "bf"
     itemsize = "*" if d in {int, float, complex} else np.dtype(d).itemsize
     return f"{np.dtype(d).kind}{itemsize}"

   out = ", ".join([""] + [name(t) for t in types]) + "\n"
   for t1 in types:
     row = [name(t1)]
     for t2 in types:
       # _lattice_result_type returns the promoted dtype and a weak-type flag;
       # weak results are mapped back to the corresponding Python scalar type.
       t, weak_type = dtypes._lattice_result_type(t1, t2)
       if weak_type:
         t = type(t.type(0).item())
       row.append(name(t))
     out += ", ".join(row) + "\n"
   print(out)
JAX's type promotion rules differ from those of NumPy, as given by
:func:`numpy.promote_types`, in a number of cells of the table above.
There are three key classes of differences:
* When promoting a weakly typed value against a typed JAX value of the same category,
  JAX always prefers the precision of the JAX value. For example, ``jnp.int16(1) + 1``
  will return ``int16`` rather than promoting to ``int64`` as in NumPy.
  Note that this applies only to Python scalar values; if the constant is a NumPy
  array then the above lattice is used for type promotion.
  For example, ``jnp.int16(1) + np.array(1)`` will return ``int64``.
* When promoting an integer or boolean type against a floating-point or complex
  type, JAX always prefers the type of the floating-point or complex type.
* JAX supports the
  `bfloat16 <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_
  non-standard 16-bit floating point type
  (:code:`jax.numpy.bfloat16`), which is useful for neural network training.
  The only notable promotion behavior is with respect to IEEE 754
  :code:`float16`, with which :code:`bfloat16` promotes to :code:`float32`.
The differences between NumPy and JAX are motivated by the fact that
accelerator devices, such as GPUs and TPUs, either pay a significant
performance penalty to use 64-bit floating point types (GPUs) or do not
support 64-bit floating point types at all (TPUs). Classic NumPy's promotion
rules are too willing to over-promote to 64-bit types, which is problematic for
a system designed to run on accelerators.
JAX uses floating point promotion rules that are more suited to modern
accelerator devices and are less aggressive about promoting floating point
types. The promotion rules used by JAX for floating-point types are similar to
those used by PyTorch.

Effects of Python operator dispatch
-----------------------------------

Keep in mind that Python operators like ``+`` will dispatch based on the Python type of
the two values being added. This means that, for example, ``np.int16(1) + 1`` will
promote using NumPy rules, whereas ``jnp.int16(1) + 1`` will promote using JAX rules.
This can lead to potentially confusing non-associative promotion semantics when
the two types of promotion are combined;
for example with ``np.int16(1) + 1 + jnp.int16(1)``.
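
Checking the Python type of each result makes the dispatch explicit; the snippet
below illustrates this (the NumPy assertions assume NumPy's default scalar behavior):

```python
import numpy as np
import jax
import jax.numpy as jnp

# Both operands are NumPy/Python types: NumPy's rules apply,
# and the result is a NumPy scalar.
a = np.int16(1) + 1
assert isinstance(a, np.integer)

# The left operand is a JAX array: JAX's rules apply,
# and the result is a jax.Array.
b = jnp.int16(1) + 1
assert isinstance(b, jax.Array)
assert b.dtype == np.dtype('int16')
```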
.. _weak-types:
Weakly-typed values in JAX
--------------------------

*Weakly-typed* values in JAX can in most cases be thought of as having promotion behavior
equivalent to that of Python scalars, such as the integer scalar ``2`` in the following:

.. code-block:: python

    >>> x = jnp.arange(5, dtype='int8')
    >>> 2 * x
    Array([0, 2, 4, 6, 8], dtype=int8)

JAX's weak type framework is designed to prevent unwanted type promotion within
binary operations between JAX values and values with no explicit user-specified type,
such as Python scalar literals. For example, if ``2`` were not treated as weakly-typed,
the expression above would lead to an implicit type promotion:

.. code-block:: python

    >>> jnp.int32(2) * x
    Array([0, 2, 4, 6, 8], dtype=int32)

When used in JAX, Python scalars are sometimes promoted to :class:`jax.Array`
objects, for example during JIT compilation. To maintain the desired promotion
semantics in this case, :class:`jax.Array` objects carry a ``weak_type`` flag
that can be seen in an array's string representation:

.. code-block:: python

    >>> jnp.asarray(2)
    Array(2, dtype=int32, weak_type=True)

If the ``dtype`` is specified explicitly, it will instead result in a standard
strongly-typed array value:

.. code-block:: python

    >>> jnp.asarray(2, dtype='int32')
    Array(2, dtype=int32)

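
The ``weak_type`` flag is also exposed as an attribute on the array, which makes it
easy to see how weak and strong scalars promote differently (a short sketch using the
``weak_type`` property of :class:`jax.Array`):

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(5, dtype='int8')

# A bare Python scalar becomes a weakly-typed int32 array...
weak = jnp.asarray(2)
assert weak.weak_type
# ...which defers to x's dtype under promotion:
assert (weak * x).dtype == np.dtype('int8')

# An explicit dtype produces a strong array, which forces promotion:
strong = jnp.asarray(2, dtype='int32')
assert not strong.weak_type
assert (strong * x).dtype == np.dtype('int32')
```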
.. _strict-dtype-promotion:
Strict dtype promotion
----------------------

In some contexts it can be useful to disable implicit type promotion behavior, and
instead require all promotions to be explicit. This can be done in JAX by setting the
``jax_numpy_dtype_promotion`` flag to ``'strict'``. Locally, it can be done with a
context manager:

.. code-block:: python

    >>> x = jnp.float32(1)
    >>> y = jnp.int32(1)
    >>> with jax.numpy_dtype_promotion('strict'):
    ...   z = x + y  # doctest: +IGNORE_EXCEPTION_DETAIL
    ...
    Traceback (most recent call last):
    TypePromotionError: Input dtypes ('float32', 'int32') have no available implicit
    dtype promotion path when jax_numpy_dtype_promotion=strict. Try explicitly casting
    inputs to the desired output type, or set jax_numpy_dtype_promotion=standard.

For convenience, strict promotion mode will still allow safe weakly-typed promotions,
so you can still write code that mixes JAX arrays and Python scalars:

.. code-block:: python

    >>> with jax.numpy_dtype_promotion('strict'):
    ...   z = x + 1
    >>> print(z)
    2.0

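
In strict mode, mixed-dtype expressions are written with explicit casts so the output
type is unambiguous; for example (a sketch using :meth:`~jax.Array.astype`):

```python
import jax
import jax.numpy as jnp

x = jnp.float32(1)
y = jnp.int32(1)

with jax.numpy_dtype_promotion('strict'):
    # x + y would raise TypePromotionError here; casting y first
    # makes both operands float32, so no implicit promotion occurs.
    z = x + y.astype('float32')

assert z.dtype == 'float32'
```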
If you would prefer to set the configuration globally, you can do so using the standard
configuration update::

  jax.config.update('jax_numpy_dtype_promotion', 'strict')

To restore the default standard type promotion, set this configuration to ``'standard'``::

  jax.config.update('jax_numpy_dtype_promotion', 'standard')
