# Type promotion semantics

JAX's type promotion rules (i.e., the result of `jax.numpy.promote_types()` for each pair of types) are determined via the following type promotion lattice:

where, for example, `b1` means `np.bool_`, `i2` means `np.int16`, `u4` means `np.uint32`, `bf` means `jax.numpy.bfloat16`, `f2` means `np.float16`, `c8` means `np.complex128` (so `c4` means `np.complex64`), `i*` means Python `int`, `f*` means Python `float`, and `c*` means Python `complex`.

Promotion between any two types is given by their join on this lattice, which generates the following binary promotion table:

| | b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 | i* | f* | c* |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| b1 | b1 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 | i8 | f8 | c8 |
| u1 | u1 | u1 | u2 | u4 | u8 | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 | u1 | f8 | c8 |
| u2 | u2 | u2 | u2 | u4 | u8 | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 | u2 | f8 | c8 |
| u4 | u4 | u4 | u4 | u4 | u8 | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c4 | c8 | u4 | f8 | c8 |
| u8 | u8 | u8 | u8 | u8 | u8 | f8 | f8 | f8 | f8 | bf | f2 | f4 | f8 | c4 | c8 | u8 | f8 | c8 |
| i1 | i1 | i2 | i4 | i8 | f8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 | i1 | f8 | c8 |
| i2 | i2 | i2 | i4 | i8 | f8 | i2 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 | i2 | f8 | c8 |
| i4 | i4 | i4 | i4 | i8 | f8 | i4 | i4 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 | i4 | f8 | c8 |
| i8 | i8 | i8 | i8 | i8 | f8 | i8 | i8 | i8 | i8 | bf | f2 | f4 | f8 | c4 | c8 | i8 | f8 | c8 |
| bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | bf | f4 | f4 | f8 | c4 | c8 | bf | bf | c4 |
| f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f2 | f4 | f2 | f4 | f8 | c4 | c8 | f2 | f2 | c4 |
| f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f4 | f8 | c4 | c8 | f4 | f4 | c4 |
| f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | c8 | c8 | f8 | f8 | c8 |
| c4 | c4 | c4 | c4 | c4 | c4 | c4 | c4 | c4 | c4 | c4 | c4 | c4 | c8 | c4 | c8 | c4 | c4 | c4 |
| c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 |
| i* | i8 | u1 | u2 | u4 | u8 | i1 | i2 | i4 | i8 | bf | f2 | f4 | f8 | c4 | c8 | i8 | f8 | c8 |
| f* | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | f8 | bf | f2 | f4 | f8 | c4 | c8 | f8 | f8 | c8 |
| c* | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c8 | c4 | c4 | c4 | c8 | c4 | c8 | c8 | c8 | c8 |
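To make "join on the lattice" concrete, here is a minimal pure-Python sketch. The edge list `LATTICE` is an assumption: a hypothetical encoding of the lattice (each type mapped to the types immediately above it) whose joins reproduce the table above. `LATTICE`, `DEFAULTS`, `upper_bounds`, and `join` are illustrative names, not part of JAX's API. As in the table, weak Python-scalar results (`i*`, `f*`, `c*`) are reported as `i8`, `f8`, and `c8`.

```python
# Hypothetical encoding of the lattice: each type maps to the types
# immediately above it. The edges are chosen so that the joins below
# reproduce the binary promotion table; this is not a JAX API.
LATTICE = {
    "b1": ["i*"],
    "i*": ["u1", "i1"],
    "u1": ["i2", "u2"], "u2": ["i4", "u4"], "u4": ["i8", "u8"], "u8": ["f*"],
    "i1": ["i2"], "i2": ["i4"], "i4": ["i8"], "i8": ["f*"],
    "f*": ["bf", "f2", "c*"], "bf": ["f4"], "f2": ["f4"],
    "f4": ["f8", "c4"], "f8": ["c8"],
    "c*": ["c4"], "c4": ["c8"], "c8": [],
}

# The table reports weak (Python scalar) results as the 64-bit defaults.
DEFAULTS = {"i*": "i8", "f*": "f8", "c*": "c8"}


def upper_bounds(t):
    """Return every type reachable from t on LATTICE, including t itself."""
    seen, stack = {t}, [t]
    while stack:
        for above in LATTICE[stack.pop()]:
            if above not in seen:
                seen.add(above)
                stack.append(above)
    return seen


def join(a, b):
    """Least upper bound of a and b, displayed as in the table."""
    common = upper_bounds(a) & upper_bounds(b)
    # The join is the common upper bound that all other common bounds lie above.
    least = [c for c in common if common <= upper_bounds(c)]
    if len(least) != 1:
        raise ValueError(f"no unique join for {a!r} and {b!r}")
    return DEFAULTS.get(least[0], least[0])


print(join("u1", "i1"))  # i2
print(join("u8", "i1"))  # f8
print(join("bf", "f2"))  # f4
```

The join intersects the two sets of upper bounds and picks the single bound that every other common bound lies above; on a well-formed lattice exactly one such bound exists, so the `ValueError` branch never fires for these types.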

JAX's type promotion rules differ from those of NumPy, as given by `numpy.promote_types()`, in those cells highlighted with a green background in the table above. There are three key differences:

1. When promoting a Python scalar value against a typed JAX value of the same category, JAX always prefers the precision of the JAX value. For example, `jnp.int16(1) + 1` will return `int16` rather than promoting to `int64` as in NumPy. Note that this applies only to Python scalar values; if the constant is a NumPy array, then the above lattice is used for type promotion. For example, `jnp.int16(1) + np.array(1)` will return `int64` (when 64-bit types are enabled).

2. When promoting an integer or boolean type against a floating-point or complex type, JAX always prefers the floating-point or complex type.

3. JAX supports the bfloat16 non-standard 16-bit floating point type (`jax.numpy.bfloat16`), which is useful for neural network training. The only notable promotion behavior is with respect to IEEE-754 `float16`, with which `bfloat16` promotes to `float32`.
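The three differences can be observed directly. This is a sketch assuming a recent JAX release alongside NumPy. It enables 64-bit mode with `jax.config.update("jax_enable_x64", True)`, because by default JAX canonicalizes 64-bit types to 32-bit ones and the `int64` result would otherwise come out as `int32`. The NumPy array is given an explicit `int64` dtype so the result does not depend on the platform's default integer type.

```python
import jax
import jax.numpy as jnp
import numpy as np

# By default JAX canonicalizes 64-bit types to 32-bit ones; enable x64 so
# the int64 result below matches the table.
jax.config.update("jax_enable_x64", True)

# 1. A Python scalar adopts the precision of the typed JAX value...
weak = jnp.int16(1) + 1
# ...but a NumPy array is typed, so the lattice applies: i2 + i8 -> i8.
typed = jnp.int16(1) + np.array(1, dtype=np.int64)
print(weak.dtype, typed.dtype)  # int16 int64

# 2. Integer vs. floating point: JAX keeps the floating-point type,
#    whereas NumPy widens to float64.
jax_mixed = jnp.promote_types(jnp.int32, jnp.float16)
np_mixed = np.promote_types(np.int32, np.float16)
print(jax_mixed, np_mixed)  # float16 float64

# 3. bfloat16 combined with IEEE float16 promotes to float32.
bf_f2 = jnp.promote_types(jnp.bfloat16, jnp.float16)
print(bf_f2)  # float32
```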

The differences between NumPy and JAX are motivated by the fact that accelerator devices, such as GPUs and TPUs, either pay a significant performance penalty to use 64-bit floating point types (GPUs) or do not support 64-bit floating point types at all (TPUs). Classic NumPy's promotion rules are too willing to overpromote to 64-bit types, which is problematic for a system designed to run on accelerators.

JAX uses floating point promotion rules that are more suited to modern accelerator devices and are less aggressive about promoting floating point types. The promotion rules used by JAX for floating-point types are similar to those used by PyTorch.

Note that operators like `+` will dispatch based on the Python types of the two values being added. This means that, for example, `np.int16(1) + 1` will promote using NumPy rules, whereas `jnp.int16(1) + 1` will promote using JAX rules. This can lead to potentially confusing non-associative promotion semantics when the two types of promotion are combined, for example with `np.int16(1) + 1 + jnp.int16(1)`.
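The dispatch behavior can be sketched as follows, assuming a recent JAX release. Whether the example `np.int16(1) + 1 + jnp.int16(1)` actually changes dtype with grouping depends on the installed NumPy version: NumPy 2.0 adopted NEP 50, under which `np.int16(1) + 1` stays `int16`. The sketch therefore also includes a hypothetical `int16`/`float16` variant in which NumPy's step widens to `float32` while JAX's rules keep `float16`, so the grouping matters under either NumPy version.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Python resolves `+` from the operand types: a NumPy scalar on the left
# applies NumPy's promotion rules, a JAX value applies JAX's rules.
np_sum = np.int16(1) + 1     # a NumPy scalar
jax_sum = jnp.int16(1) + 1   # a JAX array with dtype int16

# The example from the text, grouped two ways. `left` runs NumPy's rules
# first, so its dtype depends on the installed NumPy version; `right`
# only ever applies JAX's rules and stays int16.
left = (np.int16(1) + 1) + jnp.int16(1)
right = np.int16(1) + (1 + jnp.int16(1))
print(left.dtype, right.dtype)

# Hypothetical variant: NumPy widens int16 + float16 to float32,
# while JAX keeps float16, so the grouping changes the result dtype.
left_f = (np.int16(1) + np.float16(1)) + jnp.float16(1)   # float32
right_f = np.int16(1) + (np.float16(1) + jnp.float16(1))  # float16
print(left_f.dtype, right_f.dtype)
```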