Type promotion semantics

JAX’s type promotion rules (i.e., the result of jax.numpy.promote_types() for each pair of types) are determined via the following type promotion lattice:

[Figure: type promotion lattice (_images/type_lattice.svg)]

where, for example:

  • b1 means np.bool_,

  • i2 means np.int16,

  • u4 means np.uint32,

  • bf means np.bfloat16,

  • f2 means np.float16,

  • c8 means np.complex128,

  • i* means Python int,

  • f* means Python float, and

  • c* means Python complex.

Promotion between any two types is given by their join on this lattice, which generates the following binary promotion table:

    b1  u1  u2  u4  u8  i1  i2  i4  i8  bf  f2  f4  f8  c4  c8  i*  f*  c*
b1  b1  u1  u2  u4  u8  i1  i2  i4  i8  bf  f2  f4  f8  c4  c8  i8  f8  c8
u1  u1  u1  u2  u4  u8  i2  i2  i4  i8  bf  f2  f4  f8  c4  c8  u1  f8  c8
u2  u2  u2  u2  u4  u8  i4  i4  i4  i8  bf  f2  f4  f8  c4  c8  u2  f8  c8
u4  u4  u4  u4  u4  u8  i8  i8  i8  i8  bf  f2  f4  f8  c4  c8  u4  f8  c8
u8  u8  u8  u8  u8  u8  f8  f8  f8  f8  bf  f2  f4  f8  c4  c8  u8  f8  c8
i1  i1  i2  i4  i8  f8  i1  i2  i4  i8  bf  f2  f4  f8  c4  c8  i1  f8  c8
i2  i2  i2  i4  i8  f8  i2  i2  i4  i8  bf  f2  f4  f8  c4  c8  i2  f8  c8
i4  i4  i4  i4  i8  f8  i4  i4  i4  i8  bf  f2  f4  f8  c4  c8  i4  f8  c8
i8  i8  i8  i8  i8  f8  i8  i8  i8  i8  bf  f2  f4  f8  c4  c8  i8  f8  c8
bf  bf  bf  bf  bf  bf  bf  bf  bf  bf  bf  f4  f4  f8  c4  c8  bf  bf  c4
f2  f2  f2  f2  f2  f2  f2  f2  f2  f2  f4  f2  f4  f8  c4  c8  f2  f2  c4
f4  f4  f4  f4  f4  f4  f4  f4  f4  f4  f4  f4  f4  f8  c4  c8  f4  f4  c4
f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  f8  c8  c8  f8  f8  c8
c4  c4  c4  c4  c4  c4  c4  c4  c4  c4  c4  c4  c4  c8  c4  c8  c4  c4  c4
c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8  c8
i*  i8  u1  u2  u4  u8  i1  i2  i4  i8  bf  f2  f4  f8  c4  c8  i8  f8  c8
f*  f8  f8  f8  f8  f8  f8  f8  f8  f8  bf  f2  f4  f8  c4  c8  f8  f8  c8
c*  c8  c8  c8  c8  c8  c8  c8  c8  c8  c4  c4  c4  c8  c4  c8  c8  c8  c8
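
A few cells of the table above can be spot-checked with jax.numpy.promote_types(). The following is only a sketch: the helper name check_cells and the particular cells chosen are assumptions made for this example, and the selection is limited to results narrower than 64 bits so that it should not depend on whether 64-bit types are enabled.

```python
import jax.numpy as jnp

# Cells copied from the table: (row type, column type, expected result).
# The selection is illustrative, not an exhaustive check of the table.
CELLS = [
    (jnp.uint8, jnp.int8, jnp.int16),             # u1, i1 -> i2
    (jnp.uint16, jnp.int16, jnp.int32),           # u2, i2 -> i4
    (jnp.int32, jnp.float16, jnp.float16),        # i4, f2 -> f2
    (jnp.bfloat16, jnp.float16, jnp.float32),     # bf, f2 -> f4
    (jnp.float32, jnp.complex64, jnp.complex64),  # f4, c4 -> c4
]


def check_cells(cells):
    """Return the cells where promote_types disagrees with the expected type.

    Hypothetical helper written for this example; it is not part of JAX.
    """
    mismatches = []
    for a, b, expected in cells:
        result = jnp.promote_types(a, b)
        if result != jnp.dtype(expected):
            mismatches.append((a, b, result))
    return mismatches


# An empty list means every selected cell matched the table.
print(check_cells(CELLS))
```

Because promotion is a lattice join, it is symmetric, so swapping the two arguments of any entry should give the same result.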

JAX’s type promotion rules differ from those of NumPy, as given by numpy.promote_types(), in the cells highlighted with a green background in the table above. There are three key differences:

  • When promoting a Python scalar value against a typed JAX value of the same category, JAX always prefers the precision of the JAX value. For example, jnp.int16(1) + 1 will return int16 rather than promoting to int64 as in NumPy.

  • When promoting an integer or boolean type against a floating-point or complex type, JAX always prefers the floating-point or complex type. For example, the i8 row of the table gives f2 in the f2 column: int64 promoted against float16 stays float16.

  • JAX supports the non-standard 16-bit floating-point type bfloat16 (jax.numpy.bfloat16), which is useful for neural network training. Its only notable promotion behavior is with respect to IEEE-754 float16: promoting bfloat16 against float16 gives float32.
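
The three differences listed above can be observed directly on small arrays. This is a minimal sketch assuming a default JAX configuration; all variable names are illustrative only.

```python
import jax.numpy as jnp
import numpy as np

# 1. A Python scalar adopts the precision of the typed JAX value.
x = jnp.int16(1)
scalar_result = (x + 1).dtype

# 2. Integer against floating point: the floating-point type wins,
#    even when it is the narrower of the two.
ints = jnp.arange(3, dtype=jnp.int32)
halves = jnp.ones(3, dtype=jnp.float16)
mixed_result = (ints + halves).dtype

#    For contrast, NumPy's promote_types widens the same pair of types.
numpy_result = np.promote_types(np.int32, np.float16)

# 3. bfloat16 against IEEE-754 float16 promotes to float32.
bf = jnp.ones(3, dtype=jnp.bfloat16)
half_result = (bf + halves).dtype

print(scalar_result, mixed_result, half_result)  # int16 float16 float32
print(numpy_result)                              # float64
```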

These differences are motivated by the fact that accelerator devices, such as GPUs and TPUs, either pay a significant performance penalty to use 64-bit floating point types (GPUs) or do not support 64-bit floating point types at all (TPUs). Classic NumPy’s promotion rules are too willing to overpromote to 64-bit types, which is problematic for a system designed to run on accelerators.

JAX uses floating-point promotion rules that are better suited to modern accelerator devices and are less aggressive about promoting floating-point types. The promotion rules used by JAX for floating-point types are similar to those used by PyTorch.