class jax.Array

Array base class for JAX

jax.Array is the public interface for JAX arrays and tracers. Its main applications are instance checks and type annotations; for example:

import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array)  # returns True both inside and outside traced functions.

def f(x: jax.Array) -> jax.Array:  # type annotations are valid for traced and non-traced types.
  return x
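The claim that instance checks also hold inside traced functions can be checked directly. The following is a minimal sketch: the function name `check` is hypothetical, and the `isinstance` test runs once at trace time, when `x` is a tracer rather than a concrete array. A NumPy array is included for contrast, since it is not a jax.Array.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def check(x):  # hypothetical helper for illustration
    # Under jit, x is a tracer, yet it still passes the jax.Array instance check.
    assert isinstance(x, jax.Array)
    return x * 2

x = jnp.arange(5)
print(isinstance(x, jax.Array))              # concrete array outside tracing
y = check(x)                                 # the assert above runs during tracing
print(isinstance(y, jax.Array))              # the jitted result is also a jax.Array
print(isinstance(np.arange(5), jax.Array))   # a NumPy ndarray is not a jax.Array
```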

jax.Array should not be used directly for creation of arrays; instead you should use array creation routines offered in jax.numpy, such as jax.numpy.array(), jax.numpy.zeros(), jax.numpy.ones(), jax.numpy.full(), jax.numpy.arange(), etc.
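As a short sketch of the recommended route, each of the jax.numpy creation routines named above returns an object that satisfies the jax.Array instance check. The concrete runtime class name is an implementation detail, so the loop below prints it without asserting what it is.

```python
import jax
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])     # from a Python list
z = jnp.zeros((2, 3))              # 2x3 array of zeros
o = jnp.ones(4, dtype=jnp.int32)   # explicit dtype
f = jnp.full((2, 2), 7)            # constant fill value
r = jnp.arange(5)                  # 0, 1, 2, 3, 4

for arr in (a, z, o, f, r):
    # All creation routines produce jax.Array instances.
    assert isinstance(arr, jax.Array)
    print(type(arr).__name__, arr.shape, arr.dtype)
```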

at: Helper property for index update functionality.
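JAX arrays are immutable, so index updates go through the `at` helper (`x.at[idx]`), which returns a new array instead of modifying `x` in place. The sketch below uses the `set`, `add`, and `multiply` update methods; the values are arbitrary and chosen only for illustration.

```python
import jax.numpy as jnp

x = jnp.arange(5.0)                 # [0., 1., 2., 3., 4.]

try:
    x[1] = 10.0                     # in-place assignment is rejected
except TypeError:
    print("JAX arrays do not support item assignment")

y = x.at[1].set(10.0)               # replace one element
z = y.at[2:4].add(1.0)              # add to a slice
w = z.at[-1].multiply(3.0)          # scale the last element

print(x)  # the original array is unchanged
print(w)
```

Because each call returns a fresh array, the original `x` keeps its values. Under `jax.jit`, XLA can often perform such updates in place, so the functional style need not imply a copy.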