Type promotion semantics

This document describes JAX’s type promotion rules, i.e., the result of jax.numpy.promote_types() for each pair of types. For background on the design considerations behind these rules, see Design of Type Promotion Semantics for JAX.

JAX’s type promotion behavior is determined via the following type promotion lattice:

[Figure: JAX type promotion lattice (type_lattice.svg)]

where, for example:

  • b1 means np.bool_,

  • i2 means np.int16,

  • u4 means np.uint32,

  • bf means np.bfloat16,

  • f2 means np.float16,

  • c8 means np.complex64,

  • i* means Python int or weakly-typed int,

  • f* means Python float or weakly-typed float, and

  • c* means Python complex or weakly-typed complex.

(For more about weak types, see Weakly-typed values in JAX below.)

Promotion between any two types is given by their join on this lattice, which generates the following binary promotion table:

b1u1u2u4u8i1i2i4i8bff2f4f8c8c16i*f*c*
b1b1u1u2u4u8i1i2i4i8bff2f4f8c8c16i*f*c*
u1u1u1u2u4u8i2i2i4i8bff2f4f8c8c16u1f*c*
u2u2u2u2u4u8i4i4i4i8bff2f4f8c8c16u2f*c*
u4u4u4u4u4u8i8i8i8i8bff2f4f8c8c16u4f*c*
u8u8u8u8u8u8f*f*f*f*bff2f4f8c8c16u8f*c*
i1i1i2i4i8f*i1i2i4i8bff2f4f8c8c16i1f*c*
i2i2i2i4i8f*i2i2i4i8bff2f4f8c8c16i2f*c*
i4i4i4i4i8f*i4i4i4i8bff2f4f8c8c16i4f*c*
i8i8i8i8i8f*i8i8i8i8bff2f4f8c8c16i8f*c*
bfbfbfbfbfbfbfbfbfbfbff4f4f8c8c16bfbfc8
f2f2f2f2f2f2f2f2f2f2f4f2f4f8c8c16f2f2c8
f4f4f4f4f4f4f4f4f4f4f4f4f4f8c8c16f4f4c8
f8f8f8f8f8f8f8f8f8f8f8f8f8f8c16c16f8f8c16
c8c8c8c8c8c8c8c8c8c8c8c8c8c16c8c16c8c8c8
c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16c16
i*i*u1u2u4u8i1i2i4i8bff2f4f8c8c16i*f*c*
f*f*f*f*f*f*f*f*f*f*bff2f4f8c8c16f*f*c*
c*c*c*c*c*c*c*c*c*c*c8c8c8c16c8c16c*c*c*
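Individual entries of the table can be spot-checked with jax.numpy.promote_types(). The following is a minimal sketch, not an exhaustive test: the `expected` list is a hypothetical helper introduced here for illustration, and it covers only a handful of strongly-typed pairs whose result is at most 32 bits wide, so it should not depend on whether 64-bit types are enabled.

```python
import numpy as np
import jax.numpy as jnp

# (left, right, entry read from the table) for a few strongly-typed pairs.
# `expected` is an illustrative helper, not part of the JAX API.
expected = [
    (jnp.bool_, jnp.uint8, jnp.uint8),            # b1 x u1 -> u1
    (jnp.uint8, jnp.int8, jnp.int16),             # u1 x i1 -> i2
    (jnp.uint16, jnp.int16, jnp.int32),           # u2 x i2 -> i4
    (jnp.int32, jnp.float16, jnp.float16),        # i4 x f2 -> f2
    (jnp.bfloat16, jnp.float16, jnp.float32),     # bf x f2 -> f4
    (jnp.float32, jnp.complex64, jnp.complex64),  # f4 x c8 -> c8
]

# Promotion is a join on the lattice, so the argument order should not matter.
results = [jnp.promote_types(a, b) for a, b, _ in expected]
swapped = [jnp.promote_types(b, a) for a, b, _ in expected]

for (a, b, want), got in zip(expected, results):
    print(f"{np.dtype(a).name:>9} x {np.dtype(b).name:<9} -> {got.name}"
          f" (table: {np.dtype(want).name})")
```

Each printed result should agree with the corresponding table cell, and the swapped call should give the same answer because the table is symmetric.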

JAX’s type promotion rules differ from those of NumPy, as given by numpy.promote_types(), in those cells highlighted with a green background in the table above. There are three key classes of differences:

  • When promoting a weakly typed value against a typed JAX value of the same category, JAX always prefers the precision of the JAX value. For example, jnp.int16(1) + 1 will return int16 rather than promoting to int64 as in NumPy. Note that this applies only to Python scalar values; if the constant is a NumPy array then the above lattice is used for type promotion. For example, jnp.int16(1) + np.array(1) will return int64 (or int32 when 64-bit types are disabled, which is the JAX default).

  • When promoting an integer or boolean type against a floating-point or complex type, JAX always prefers the type of the floating-point or complex type.

  • JAX supports the bfloat16 non-standard 16-bit floating point type (jax.numpy.bfloat16), which is useful for neural network training. The only notable promotion behavior is with respect to IEEE-754 float16, with which bfloat16 promotes to a float32.
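The three classes of differences listed above can be observed directly from the result dtypes of a few binary operations. This is an illustrative sketch only; the variable names are made up for the example, and the NumPy array is given an explicit int32 dtype so that the outcome does not depend on the 64-bit setting.

```python
import numpy as np
import jax.numpy as jnp

# 1. A Python scalar is weakly typed, so the JAX value's precision wins.
weak_scalar = jnp.int16(1) + 1
print(weak_scalar.dtype)        # int16

#    A NumPy array is strongly typed, so the lattice applies (i2 x i4 -> i4).
strong_array = jnp.int16(1) + np.array(1, dtype=np.int32)
print(strong_array.dtype)       # int32

# 2. Integer against floating point: JAX keeps the floating-point type,
#    whereas NumPy widens so that every int32 value is representable.
int_float = jnp.ones(3, dtype=jnp.int32) + jnp.ones(3, dtype=jnp.float16)
numpy_int_float = np.promote_types(np.int32, np.float16)
print(int_float.dtype)          # float16
print(numpy_int_float)          # float64

# 3. bfloat16 against IEEE float16 promotes to float32.
mixed_half = jnp.ones(3, dtype=jnp.bfloat16) + jnp.ones(3, dtype=jnp.float16)
print(mixed_half.dtype)         # float32
```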

The differences between NumPy and JAX are motivated by the fact that accelerator devices, such as GPUs and TPUs, either pay a significant performance penalty to use 64-bit floating point types (GPUs) or do not support 64-bit floating point types at all (TPUs). Classic NumPy’s promotion rules are too willing to overpromote to 64-bit types, which is problematic for a system designed to run on accelerators.

JAX uses floating point promotion rules that are more suited to modern accelerator devices and are less aggressive about promoting floating point types. The promotion rules used by JAX for floating-point types are similar to those used by PyTorch.

Effects of Python operator dispatch

Keep in mind that Python operators like + will dispatch based on the Python type of the two values being added. This means that, for example, np.int16(1) + 1 will promote using NumPy rules, whereas jnp.int16(1) + 1 will promote using JAX rules. This can lead to potentially confusing non-associative promotion semantics when the two types of promotion are combined; for example with np.int16(1) + 1 + jnp.int16(1).
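The grouping dependence can be made visible by adding explicit parentheses. The sketch below only prints the resulting dtypes: the left-grouped result depends on the installed NumPy version (classic value-based promotion versus the newer NumPy 2 rules) and on whether 64-bit types are enabled in JAX, so no output is asserted for it here. The names `left` and `right` are introduced for illustration.

```python
import numpy as np
import jax.numpy as jnp

# Grouped left to right (Python's default evaluation order): the first +
# involves only a NumPy scalar and a Python int, so NumPy's rules choose
# the intermediate dtype before JAX ever sees the value.
left = (np.int16(1) + 1) + jnp.int16(1)

# Grouped right to left: the inner + involves a JAX value, so the Python
# int is treated as weakly typed, and the outer + then combines two
# int16 values.
right = np.int16(1) + (1 + jnp.int16(1))

print("left-grouped :", left.dtype)
print("right-grouped:", right.dtype)
```

If the two printed dtypes differ in a given environment, that difference is the non-associativity described above.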

Weakly-typed values in JAX

Weakly-typed values in JAX can in most cases be thought of as having promotion behavior equivalent to that of Python scalars, such as the integer scalar 2 in the following:

>>> x = jnp.arange(5, dtype='int8')
>>> 2 * x
Array([0, 2, 4, 6, 8], dtype=int8)

JAX’s weak type framework is designed to prevent unwanted type promotion within binary operations between JAX values and values with no explicitly user-specified type, such as Python scalar literals. For example, if 2 were not treated as weakly-typed, the expression above would lead to an implicit type promotion:

>>> jnp.int32(2) * x
Array([0, 2, 4, 6, 8], dtype=int32)

When used in JAX, Python scalars are sometimes converted to jax.Array objects, for example during JIT compilation. To maintain the desired promotion semantics in this case, jax.Array objects carry a weak_type flag that can be seen in an array’s string representation:

>>> jnp.asarray(2)
Array(2, dtype=int32, weak_type=True)

If the dtype is specified explicitly, it will instead result in a standard strongly-typed array value:

>>> jnp.asarray(2, dtype='int32')
Array(2, dtype=int32)
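To see why the flag matters under JIT compilation, a Python scalar can be passed through a jitted function. This is a minimal sketch; the function `scale` is a hypothetical example and not part of JAX.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5, dtype='int8')

@jax.jit
def scale(x, factor):
    # Inside jit, `factor` is a traced array rather than a Python scalar.
    return factor * x

# A Python scalar argument keeps its weak type through tracing,
# so the int8 precision of x is preserved.
weak_result = scale(x, 2)
print(weak_result.dtype)      # int8

# An explicitly typed scalar is strongly typed and promotes as usual.
strong_result = scale(x, jnp.int32(2))
print(strong_result.dtype)    # int32

# The flag appears in the repr only for weakly-typed arrays.
weak_repr = repr(jnp.asarray(2))
strong_repr = repr(jnp.asarray(2, dtype='int32'))
print(weak_repr)
print(strong_repr)
```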