jax.typing module

The JAX typing module is where JAX-specific static type annotations live. This submodule is a work in progress; the proposal behind the types exported here is described at https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html.

The currently-available types are:

  • jax.Array: annotation for any JAX array or tracer (i.e. representations of arrays within JAX transforms).

  • jax.typing.ArrayLike: annotation for any value that is safe to implicitly cast to a JAX array; this includes jax.Array, numpy.ndarray, as well as Python builtin numeric values (e.g. int, float, etc.) and numpy scalar values (e.g. numpy.int32, numpy.float64, etc.).

  • jax.typing.DTypeLike: annotation for any value that can be cast to a JAX-compatible dtype; this includes strings (e.g. 'float32', 'int32'), scalar types (e.g. float, np.float32), dtypes (e.g. np.dtype('float32')), or objects with a dtype attribute (e.g. jnp.float32, jnp.int32).

We may add additional types here in future releases.
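As a rough illustration of what ArrayLike covers, the sketch below passes several kinds of input through a small conversion helper. The helper name to_array is hypothetical and not part of JAX; it only wraps jnp.asarray, and the inputs are chosen to mirror the categories listed above.

```python
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def to_array(x: ArrayLike) -> Array:
    """Hypothetical helper: convert any array-like input to a jax.Array."""
    return jnp.asarray(x)


inputs = [
    jnp.ones(3),      # jax.Array
    np.arange(3),     # numpy.ndarray
    2,                # Python int
    1.5,              # Python float
    np.float32(1.5),  # NumPy scalar
]

outputs = [to_array(x) for x in inputs]

# Every result is a jax.Array, whatever the input type was.
for x, y in zip(inputs, outputs):
    print(type(x).__name__, "->", type(y).__name__, y.dtype, y.shape)
```

The resulting dtypes depend on JAX's configuration (for example, whether 64-bit mode is enabled), so the sketch prints them rather than assuming particular values.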

JAX Typing Best Practices

When annotating JAX arrays in public API functions, we recommend using ArrayLike for array inputs, and Array for array outputs.

For example, your function might look like this:

import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def my_function(x: ArrayLike) -> Array:
  # Runtime type validation, Python 3.10 or newer:
  if not isinstance(x, ArrayLike):
    raise TypeError(f"Expected arraylike input; got {x}")
  # Runtime type validation, any Python version:
  if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
    raise TypeError(f"Expected arraylike input; got {x}")

  # Convert input to jax.Array:
  x_arr = jnp.asarray(x)

  # ... do some computation; JAX functions will return Array types:
  result = x_arr.sum(0) / x_arr.shape[0]

  # return an Array
  return result

Most of JAX's public APIs follow this pattern. Note in particular that we recommend that JAX functions not accept sequences such as list or tuple in place of arrays, as this can cause extra overhead in JAX transforms like jax.jit() and can behave in unexpected ways with batch-wise transforms like jax.vmap() or jax.pmap(). For more information on this, see "Non-array inputs: NumPy vs JAX".
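To make the recommendation about sequences concrete, here is a minimal sketch of how a jax.numpy reduction treats a Python list, and of the explicit conversion a caller can perform instead. The variable names are illustrative only.

```python
import jax.numpy as jnp

values = [1.0, 2.0, 3.0]

# jax.numpy functions reject Python lists rather than converting them silently.
list_rejected = False
try:
    jnp.sum(values)
except TypeError as err:
    list_rejected = True
    print(f"rejected: {err}")

# Converting explicitly at the call site keeps the cost visible to the caller.
total = jnp.sum(jnp.asarray(values))
print(total)
```

Doing the conversion once, outside any transformed function, means the list is turned into a single array up front instead of being handled element by element inside a transform.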

List of Members

  • ArrayLike: type annotation for JAX array-like objects.

  • DTypeLike: alias of str | type[Any] | dtype | SupportsDType
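The members of the DTypeLike alias can be seen side by side in the sketch below, which passes one value of each kind to the same function. The helper zeros_with is hypothetical; it simply forwards the dtype argument to jnp.zeros.

```python
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import DTypeLike


def zeros_with(dtype: DTypeLike) -> Array:
    """Hypothetical helper: build a small array with the requested dtype."""
    return jnp.zeros(2, dtype=dtype)


specs = [
    "float32",            # string
    np.float32,           # scalar type
    np.dtype("float32"),  # dtype instance
    jnp.float32,          # object with a dtype attribute
]

# Each spelling resolves to the same dtype on the resulting array.
resolved = [zeros_with(d).dtype for d in specs]
print(resolved)
```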