class jax.Array

Array base class for JAX

jax.Array is the public interface for instance checks and type annotations of JAX arrays and tracers; for example:

import jax
import jax.numpy as jnp
from jax import Array

x = jnp.arange(5)
isinstance(x, jax.Array)  # returns True both inside and outside traced functions.

def f(x: Array) -> Array:  # type annotations are valid for traced and non-traced types.
  return x

jax.Array should not be used directly to create arrays; instead, use the array creation routines in jax.numpy, such as jax.numpy.array(), jax.numpy.zeros(), jax.numpy.ones(), jax.numpy.full(), jax.numpy.arange(), etc.
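A minimal sketch of the creation routines named above, together with the instance check applied to a traced value. The helper name `double` is hypothetical and used only for illustration; shapes and values are arbitrary, and default 32-bit mode (x64 disabled) is assumed.

```python
import jax
import jax.numpy as jnp

# Each jax.numpy creation routine returns a jax.Array.
a = jnp.array([1.0, 2.0, 3.0])
z = jnp.zeros((2, 3))
o = jnp.ones(4, dtype=jnp.int32)
full_arr = jnp.full((2, 2), 7.0)
r = jnp.arange(5)

@jax.jit
def double(x: jax.Array) -> jax.Array:  # hypothetical helper for illustration
  # Under jit, x is a tracer, yet it still passes the isinstance check.
  if not isinstance(x, jax.Array):
    raise TypeError("expected a jax.Array")
  return x * 2

doubled = double(r)  # [0, 2, 4, 6, 8]
```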

Methods

all([axis, out, keepdims, where])

Test whether all array elements along a given axis evaluate to True.

any([axis, out, keepdims, where])

Test whether any array element along a given axis evaluates to True.
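A short sketch of the two reductions above, showing the full reduction and the per-axis forms. Non-boolean inputs are tested for truthiness (nonzero counts as True), as in NumPy.

```python
import jax.numpy as jnp

x = jnp.array([[True, False, True],
               [True, True, True]])

every = x.all()             # False: one element is False
rows_all = x.all(axis=1)    # [False, True]: reduced along each row
cols_any = x.any(axis=0)    # [True, True, True]: each column has a True
nums_any = jnp.array([0, 0, 3]).any()  # True: a nonzero value counts as True
```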

argmax([axis, out, keepdims])

Returns the indices of the maximum values along an axis.

argmin([axis, out, keepdims])

Returns the indices of the minimum values along an axis.
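A minimal sketch of argmax and argmin. Without an axis, the returned index refers to the flattened array; `keepdims=True` keeps the reduced axis with length one.

```python
import jax.numpy as jnp

x = jnp.array([[3, 7, 1],
               [9, 2, 5]])

flat_idx = x.argmax()                   # 3: index into the flattened array (value 9)
row_idx = x.argmax(axis=1)              # [1, 0]
col_idx = x.argmin(axis=0)              # [0, 1, 0]
kept = x.argmax(axis=1, keepdims=True)  # shape (2, 1)
```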

argpartition(kth[, axis])

Perform an indirect partition along the given axis.

argsort([axis, kind, order])

Returns the indices that would sort an array.
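A short sketch relating argsort to sorting: indexing the array with the returned indices yields the same result as `sort()`.

```python
import jax.numpy as jnp

x = jnp.array([30, 10, 20])
order = x.argsort()   # [1, 2, 0]
gathered = x[order]   # [10, 20, 30], matching x.sort()
```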


Copy the array and cast to a specified dtype.
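The entry above describes dtype casting. A minimal sketch using `astype`, which returns a new array and leaves the original untouched (JAX arrays are immutable).

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3], dtype=jnp.int32)
y = x.astype(jnp.float32)  # new array; x itself is not modified
```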

choose(choices[, out, mode])

Construct an array from an index array and a list of arrays to choose from.

clip([min, max, out])

Return an array whose values are limited to a specified range.
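A short sketch of clip using the `min`/`max` keywords listed in the signature; omitting one bound leaves that side unclipped.

```python
import jax.numpy as jnp

x = jnp.array([-2, 0, 3, 8])
both = x.clip(min=0, max=5)  # [0, 0, 3, 5]
upper = x.clip(max=1)        # [-2, 0, 1, 1]: no lower bound applied
```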

compress(condition[, axis, out])

Return selected slices of this array along given axis.


Return the complex conjugate, element-wise.


Return the complex conjugate, element-wise.


Return an array copy of the given object.

cumprod([axis, dtype, out])

Return the cumulative product of elements along a given axis.

cumsum([axis, dtype, out])

Return the cumulative sum of the elements along a given axis.
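A minimal sketch of the cumulative reductions above. With the default `axis=None`, the array is flattened before accumulating.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

flat = x.cumsum()           # axis=None flattens first: [1, 3, 6, 10, 15, 21]
by_row = x.cumsum(axis=1)   # [[1, 3, 6], [4, 9, 15]]
by_col = x.cumprod(axis=0)  # [[1, 2, 3], [4, 10, 18]]
```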

diagonal([offset, axis1, axis2])

Return specified diagonals.

dot(b, *[, precision, preferred_element_type])

Dot product of two arrays.
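A short sketch of dot for matrix-vector and matrix-matrix products. On some accelerators (notably TPUs), the default float32 matmul precision is reduced, so the `precision` keyword is shown requesting `jax.lax.Precision.HIGHEST`; results are compared approximately.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
v = jnp.array([1.0, 1.0])

mv = a.dot(v)  # matrix-vector product: [3., 7.]
# Request the highest available matmul precision explicitly.
mm = a.dot(a, precision=jax.lax.Precision.HIGHEST)  # [[7., 10.], [15., 22.]]
```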


Return a contiguous flattened array.


Copy an element of an array to a standard Python scalar and return it.
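A minimal sketch of copying a value out to a Python scalar. This needs a concrete value, so it runs outside `jax.jit`; calling it on a traced value inside a jitted function is expected to fail, since no concrete value exists at trace time.

```python
import jax.numpy as jnp

x = jnp.array(3.5)
val = x.item()                # Python float 3.5, copied to the host
count = jnp.array([7]).item() # size-1 arrays also work: Python int 7
```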

max([axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis.

mean([axis, dtype, out, keepdims, where])

Compute the arithmetic mean along the specified axis.

min([axis, out, keepdims, initial, where])

Return the minimum of an array or minimum along an axis.

nonzero(*[, size, fill_value])

Return the indices of the elements that are non-zero.
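The number of nonzero elements depends on the data, so the output shape is unknown at trace time. A sketch showing the eager call and a jitted call that fixes the shape with `size`, padding unused slots with `fill_value`. The helper name `padded_nonzero` is hypothetical.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5])
(idx,) = x.nonzero()  # one index array per dimension: [1, 3]

@jax.jit
def padded_nonzero(x):  # hypothetical helper for illustration
  # size fixes the output shape so it can be traced; unused slots get fill_value.
  (i,) = x.nonzero(size=3, fill_value=-1)
  return i

padded = padded_nonzero(x)  # [1, 3, -1]
```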

prod([axis, dtype, out, keepdims, initial, ...])

Return the product of array elements over a given axis.

ptp([axis, out, keepdims])

Range of values (maximum - minimum) along an axis.


Return a contiguous flattened array.

repeat(repeats[, axis, total_repeat_length])

Repeat each element of an array after itself.
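A sketch of repeat with a scalar count and with per-element counts. When the counts are traced (inside `jax.jit`), the output length is unknown, so `total_repeat_length` supplies it. The helper name `repeat_traced` is hypothetical.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
twice = x.repeat(2)            # [1, 1, 2, 2, 3, 3]
counts = jnp.array([1, 0, 2])
varied = x.repeat(counts)      # [1, 3, 3]

@jax.jit
def repeat_traced(x, counts):  # hypothetical helper for illustration
  # With traced repeats the output length is unknown, so fix it explicitly.
  return x.repeat(counts, total_repeat_length=3)

traced = repeat_traced(x, counts)  # [1, 3, 3]
```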

reshape(*args[, order])

Returns an array containing the same data with a new shape.
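A short sketch of the accepted shape forms: separate integers, a single tuple, or `-1` for one inferred dimension.

```python
import jax.numpy as jnp

x = jnp.arange(6)
a = x.reshape(2, 3)    # shape given as separate arguments
b = x.reshape((3, 2))  # or as a single tuple
c = x.reshape(-1, 2)   # -1 infers that dimension: shape (3, 2)
```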

round([decimals, out])

Round an array to the given number of decimals.
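As in NumPy, exact ties are rounded to the nearest even value. A sketch showing that behavior and a nonzero `decimals`; the decimal case is compared approximately because values such as 1.3 are not exactly representable in floating point.

```python
import jax.numpy as jnp

halves = jnp.array([0.5, 1.5, 2.5]).round()           # [0., 2., 2.]: ties go to even
tenths = jnp.array([1.26, 3.14159]).round(decimals=1)  # approximately [1.3, 3.1]
```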

searchsorted(v[, side, sorter, method])

Find indices where elements should be inserted to maintain order.
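A minimal sketch of searchsorted on an already-sorted array; `side` decides whether an equal value is inserted before or after the existing entry.

```python
import jax.numpy as jnp

s = jnp.array([1, 3, 5, 7])                  # must already be sorted
left = s.searchsorted(5)                     # 2: before the existing 5
right = s.searchsorted(5, side='right')      # 3: after the existing 5
many = s.searchsorted(jnp.array([0, 4, 8]))  # [0, 2, 4]
```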

sort([axis, kind, order])

Return a sorted copy of an array.


Remove axes of length one from the array.

std([axis, dtype, out, ddof, keepdims, where])

Compute the standard deviation along the specified axis.

sum([axis, dtype, out, keepdims, initial, ...])

Sum of array elements over a given axis.

swapaxes(axis1, axis2)

Interchange two axes of an array.

take(indices[, axis, out, mode, ...])

Take elements from an array along an axis.

trace([offset, axis1, axis2, dtype, out])

Return the sum along diagonals of the array.
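A short sketch connecting diagonal and trace: `offset` selects a diagonal above (positive) the main one, and trace sums the main diagonal.

```python
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
main = x.diagonal()              # [0, 4, 8]
upper = x.diagonal(offset=1)     # [1, 5]
total = x.trace()                # 12, the sum of the main diagonal
```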


Returns a view of the array with axes transposed.

var([axis, dtype, out, ddof, keepdims, where])

Compute the variance along the specified axis.
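A minimal sketch of the `ddof` argument shared by var and std: the sum of squared deviations is divided by N - ddof, so `ddof=0` (the default) gives the population variance and `ddof=1` the sample variance.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
pop_var = x.var()           # squared deviations sum to 5; 5 / 4 = 1.25
sample_var = x.var(ddof=1)  # 5 / (4 - 1) = 5 / 3
sample_std = x.std(ddof=1)  # square root of sample_var
```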

view([dtype, type])

Return a bitwise copy of the array, viewed as a new dtype.
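Unlike a dtype cast, view reinterprets the underlying bits without numeric conversion. A short sketch reading the IEEE-754 bit pattern of a float32 value.

```python
import jax.numpy as jnp

x = jnp.array([1.0], dtype=jnp.float32)
bits = x.view(jnp.int32)  # same bytes reinterpreted: 1.0f is 0x3F800000
```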

Attributes

Returns an array with axes transposed.


Helper property for index update functionality.
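JAX arrays are immutable, so in-place assignment such as `x[1] = 10` is not supported; the `at` property provides functional updates that return a new array. A minimal sketch of `set` and `add`, including how duplicate indices behave for `add`.

```python
import jax.numpy as jnp

x = jnp.zeros(5)
y = x.at[1].set(10.0)                    # [0., 10., 0., 0., 0.]
z = x.at[jnp.array([0, 0, 2])].add(1.0)  # duplicate indices accumulate: [2., 0., 1., 0., 0.]
# x itself is unchanged: [0., 0., 0., 0., 0.]
```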



Return the imaginary part of the complex argument.


Length of one array element in bytes.


Transpose the last two dimensions of the array.
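A short sketch contrasting the two transpose attributes on a batch of matrices: `mT` swaps only the last two axes, while `T` reverses all axes.

```python
import jax.numpy as jnp

x = jnp.ones((4, 2, 3))  # e.g. a batch of four 2x3 matrices
batch_t = x.mT           # shape (4, 3, 2): only the last two axes swap
full_t = x.T             # shape (3, 2, 4): all axes reversed
```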


Total bytes consumed by the elements of the array.
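A minimal sketch relating the two size attributes: `nbytes` equals the element count times `itemsize`.

```python
import jax.numpy as jnp

x = jnp.zeros((3, 4), dtype=jnp.float32)
per_elem = x.itemsize  # 4 bytes per float32 element
total = x.nbytes       # 12 elements * 4 bytes = 48
```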


Return the real part of the complex argument.
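A short sketch of the complex-number accessors, together with the conj method listed among the methods, on a small complex64 array.

```python
import jax.numpy as jnp

z = jnp.array([1 + 2j, 3 - 4j])
re = z.real    # [1., 3.]
im = z.imag    # [2., -4.]
zc = z.conj()  # [1.-2.j, 3.+4.j]: the imaginary part flips sign
```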