jax.Array
class jax.Array
Array base class for JAX
jax.Array is the public interface for instance checks and type annotations of JAX arrays and tracers. For example:

    import jax
    import jax.numpy as jnp

    x = jnp.arange(5)
    isinstance(x, jax.Array)  # returns True both inside and outside traced functions.

    def f(x: jax.Array) -> jax.Array:  # type annotations are valid for traced and non-traced types.
        return x
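The claim that the check also holds for tracers can be verified directly. The sketch below runs the instance check on a concrete array, on a tracer inside jax.jit, and, for contrast, on a plain NumPy array, which is not a jax.Array. The helper name is_jax_array is hypothetical and exists only for this illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

def is_jax_array(x) -> bool:
    # Hypothetical helper: reports whether x passes the jax.Array instance check.
    return isinstance(x, jax.Array)

x = jnp.arange(5)
outside = is_jax_array(x)  # concrete array, outside any transformation

@jax.jit
def traced(x: jax.Array) -> jax.Array:
    # Inside jit, x is a tracer; this Python assert runs once, at trace time.
    assert isinstance(x, jax.Array)
    return x * 2

y = traced(x)
numpy_check = is_jax_array(np.arange(5))  # a NumPy ndarray is not a jax.Array

print(outside, numpy_check, y)
```

When a function should also accept NumPy arrays or Python scalars, the jax.typing module offers ArrayLike for annotating inputs, while jax.Array remains the annotation for values JAX itself produces.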
jax.Array should not be used directly to create arrays; instead, use the array creation routines offered in jax.numpy, such as jax.numpy.array(), jax.numpy.zeros(), jax.numpy.ones(), jax.numpy.full(), jax.numpy.arange(), etc.
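A short sketch of the creation routines listed above: each call returns a jax.Array without touching the jax.Array class itself. The shapes, dtypes, and fill values are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# Each routine returns a jax.Array; none of them calls jax.Array directly.
a = jnp.array([1.0, 2.0, 3.0])    # from a Python list
z = jnp.zeros((2, 3))             # all zeros; float32 unless 64-bit mode is enabled
o = jnp.ones(4, dtype=jnp.int32)  # all ones, with an explicit dtype
f = jnp.full((2, 2), 7)           # every element set to the constant 7
r = jnp.arange(5)                 # 0, 1, 2, 3, 4

for arr in (a, z, o, f, r):
    print(arr.shape, arr.dtype, isinstance(arr, jax.Array))
```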
Methods
__init__()

Attributes
at
Helper property for index update functionality.
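The at property exists because JAX arrays are immutable: item assignment such as x[1] = 10.0 raises a TypeError. A minimal sketch, assuming a 1-D float array, of how .at[...] produces an updated copy instead, using the set, add, and multiply update methods:

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Each .at[...] call returns a new array; x itself is never modified.
y = x.at[1].set(10.0)      # replace element 1
y = y.at[2:4].add(1.0)     # add 1.0 to elements 2 and 3
y = y.at[0].multiply(3.0)  # multiply element 0 (0.0 * 3.0 stays 0.0)

# x still holds [0, 0, 0, 0, 0]; y holds [0, 10, 1, 1, 0].
print(x, y)
```

The same helper also provides other update methods such as min, max, and get. Because every call returns a new array, chaining updates as above is the idiomatic pattern.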