jax.numpy package
Implements the NumPy API, using the primitives in jax.lax.
While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.

Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, JAX is often able to provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides an alternative pure indexed update function x.at[i].set(y) (see ndarray.at).

Relatedly, some NumPy functions often return views of arrays when possible (examples are transpose() and reshape()). JAX versions of such functions will return copies instead, although such copies are often optimized away by XLA when sequences of operations are compiled using jax.jit().

NumPy is very aggressive at promoting values to float64 type. JAX sometimes is less aggressive about type promotion (see Type promotion semantics).

Some NumPy routines have data-dependent output shapes (examples include unique() and nonzero()). Because the XLA compiler requires array shapes to be known at compile time, such operations are not compatible with JIT. For this reason, JAX adds an optional size argument to such functions which may be specified statically in order to use them with JIT.
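The type-promotion and data-dependent-shape differences can be seen in a short sketch. It assumes a recent JAX release with default settings; the helper name first_two_nonzero is hypothetical and only illustrates passing a static size to jnp.nonzero so the call can be traced under jax.jit.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Type promotion: NumPy widens int32 + float32 to float64,
# whereas JAX's promotion rules keep the result at float32.
np_sum = np.ones(3, dtype=np.int32) + np.ones(3, dtype=np.float32)
jnp_sum = jnp.ones(3, dtype=jnp.int32) + jnp.ones(3, dtype=jnp.float32)
print(np_sum.dtype, jnp_sum.dtype)  # float64 float32

# Data-dependent shapes: the number of nonzero entries is unknown at trace
# time, so jnp.nonzero needs a static `size` to be usable under jax.jit.
@jax.jit
def first_two_nonzero(x):  # hypothetical helper, for illustration only
    # size=2 fixes the output length; if there are fewer nonzero entries,
    # the result is padded with fill_value (0 by default).
    (idx,) = jnp.nonzero(x, size=2)
    return idx

print(first_two_nonzero(jnp.array([0, 5, 0, 7, 9])))  # [1 3]
```

Without size, the same jitted call would fail at trace time, because the output length would depend on the runtime values of x.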
Nearly all applicable NumPy functions are implemented in the jax.numpy
namespace; they are listed below.
Helper property for index update functionality. 


Calculate the absolute value elementwise. 

Calculate the absolute value elementwise. 

Add arguments elementwise. 

Test whether all array elements along a given axis evaluate to True. 

Returns True if two arrays are elementwise equal within a tolerance. 

Test whether all array elements along a given axis evaluate to True. 

Return the maximum of an array or maximum along an axis. 

Return the minimum of an array or minimum along an axis. 

Return the angle of the complex argument. 

Test whether any array element along a given axis evaluates to True. 

Append values to the end of an array. 

Apply a function to 1D slices along the given axis. 

Apply a function repeatedly over multiple axes. 

Return evenly spaced values within a given interval. 

Trigonometric inverse cosine, elementwise. 

Inverse hyperbolic cosine, elementwise. 

Inverse sine, elementwise. 

Inverse hyperbolic sine elementwise. 

Trigonometric inverse tangent, elementwise. 

Elementwise arc tangent of 

Inverse hyperbolic tangent elementwise. 

Returns the indices of the maximum values along an axis. 

Returns the indices of the minimum values along an axis. 

Returns the indices that would sort an array. 

Find the indices of array elements that are nonzero, grouped by element. 

Evenly round to the given number of decimals. 

Create an array. 

True if two arrays have the same shape and elements, False otherwise. 

Returns True if input arrays are shape consistent and all elements equal. 

Return the string representation of an array. 

Split an array into multiple subarrays. 

Return a string representation of the data in an array. 

Convert the input to an array. 

Convert inputs to arrays with at least one dimension. 

View inputs as arrays with at least two dimensions. 

View inputs as arrays with at least three dimensions. 

Compute the weighted average along the specified axis. 

Return the Bartlett window. 

Count number of occurrences of each value in array of nonnegative ints. 

Compute the bitwise AND of two arrays elementwise. 

Compute bitwise inversion, or bitwise NOT, elementwise. 

Compute the bitwise OR of two arrays elementwise. 

Compute the bitwise XOR of two arrays elementwise. 

Return the Blackman window. 

Assemble an ndarray from nested lists of blocks. 



Broadcast any number of arrays against each other. 
Broadcast the input shapes into a single shape. 


Broadcast an array to a new shape. 
Concatenate slices, scalars and array-like objects along the last axis.


Returns True if cast between data types can occur according to the casting rule. 

Return the cube-root of an array, elementwise.
alias of 


Return the ceiling of the input, elementwise. 
Abstract base class of all character string scalar types. 


Construct an array from an index array and a list of arrays to choose from. 

Clip (limit) the values in an array. 

Stack 1D arrays as columns into a 2D array. 
alias of 





Abstract base class of all complex number scalar types that are made up of floating-point numbers.

The warning raised when casting a complex dtype to a real dtype. 


Return selected slices of an array along given axis. 

Join a sequence of arrays along an existing axis. 

Return the complex conjugate, elementwise. 

Return the complex conjugate, elementwise. 

Returns the discrete, linear convolution of two one-dimensional sequences.

Return an array copy of the given object. 

Change the sign of x1 to that of x2, elementwise. 

Return Pearson product-moment correlation coefficients.

Cross-correlation of two 1-dimensional sequences.

Cosine elementwise. 

Hyperbolic cosine, elementwise. 

Counts the number of nonzero values in the array 

Estimate a covariance matrix, given data and weights. 

Return the cross product of two (arrays of) vectors. 
alias of 


Return the cumulative product of elements along a given axis. 

Return the cumulative product of elements along a given axis. 

Return the cumulative sum of the elements along a given axis. 

Convert angles from degrees to radians. 

Convert angles from radians to degrees. 

Return a new array with subarrays along an axis deleted. 

Extract a diagonal or construct a diagonal array. 

Return the indices to access the main diagonal of an array. 

Return the indices to access the main diagonal of an n-dimensional array.

Create a two-dimensional array with the flattened input as a diagonal.

Return specified diagonals. 

Calculate the n-th discrete difference along the given axis.

Return the indices of the bins to which each value in input array belongs. 

Divide arguments elementwise. 

Return elementwise quotient and remainder simultaneously. 

Dot product of two arrays. 
alias of 


Split array into multiple subarrays along the 3rd axis (depth). 

Stack arrays in sequence depth wise (along third axis). 

Create a data type object. 

The differences between consecutive elements of an array. 

Evaluates the Einstein summation convention on the operands. 

Evaluates the lowest cost contraction order for an einsum expression by 

Return a new array of given shape and type, without initializing entries. 

Return a new array with the same shape and type as a given array. 

Return (x1 == x2) elementwise. 

Calculate the exponential of all elements in the input array. 

Calculate 2**p for all p in the input array. 

Expand the shape of an array. 

Calculate 

Return the elements of an array that satisfy some condition. 

Return a 2D array with ones on the diagonal and zeros elsewhere. 

Compute the absolute values elementwise. 

Machine limits for floating point types. 

Round to nearest integer towards zero. 

Return indices that are nonzero in the flattened version of a. 

Abstract base class of all scalar types without predefined length. 

Reverse the order of elements in an array along the given axis. 

Reverse the order of elements along axis 1 (left/right). 

Reverse the order of elements along axis 0 (up/down). 
alias of 


First array elements raised to powers from second array, elementwise. 







Abstract base class of all floating-point scalar types.

Return the floor of the input, elementwise. 

Return the largest integer smaller or equal to the division of the inputs. 

Elementwise maximum of array elements. 

Elementwise minimum of array elements. 

Returns the elementwise remainder of division. 

Decompose the elements of x into mantissa and twos exponent. 

Interpret a buffer as a 1-dimensional array.

Unimplemented JAX wrapper for jnp.fromfile. 

Construct an array by executing a function over each coordinate. 

Unimplemented JAX wrapper for jnp.fromiter. 

A new 1D array initialized from text data in a string. 

Create a NumPy array from an object implementing the 

Return a new array of given shape and type, filled with fill_value. 

Return a full array with the same shape and type as a given array. 

Returns the greatest common divisor of 

Base class for numpy scalar types. 

Return numbers spaced evenly on a log scale (a geometric progression). 
Return the current print options. 


Return the gradient of an N-dimensional array.

Return the truth value of (x1 > x2) elementwise. 

Return the truth value of (x1 >= x2) elementwise. 

Return the Hamming window. 

Return the Hanning window. 

Compute the Heaviside step function. 

Compute the histogram of a dataset. 

Function to calculate only the edges of the bins used by the histogram 

Compute the bidimensional histogram of two data samples. 

Compute the multidimensional histogram of some data. 

Split an array into multiple subarrays horizontally (columnwise). 

Stack arrays in sequence horizontally (column wise). 

Given the "legs" of a right triangle, return its hypotenuse. 

Modified Bessel function of the first kind, order 0. 

Return the identity array. 

Machine limits for integer types. 

Return the imaginary part of the complex argument. 

Test whether each element of a 1D array is also present in a second array. 
A nicer way to build up index tuples for arrays. 


Return an array representing the indices of a grid. 

Abstract base class of all numeric scalar types with a (potentially) inexact representation of the values in its range, such as floating-point numbers.

Inner product of two arrays. 

Insert values along the given axis before the given indices. 
alias of 










Abstract base class of all integer scalar types. 

One-dimensional linear interpolation for monotonically increasing sample points.

Find the intersection of two arrays. 

Compute bitwise inversion, or bitwise NOT, elementwise. 

Returns a boolean array where two arrays are elementwise equal within a 

Returns a bool array, where True if input element is complex. 

Check for a complex type or an array of complex numbers. 

Test elementwise for finiteness (not infinity and not Not a Number). 

Calculates 

Test elementwise for positive or negative infinity. 

Test elementwise for NaN and return result as a boolean array. 

Test elementwise for negative infinity, return result as bool array. 

Test elementwise for positive infinity, return result as bool array. 

Returns a bool array, where True if input element is real. 

Return True if x is neither a complex type nor an array of complex numbers.

Returns True if the type of element is a scalar type. 

Returns True if first argument is a typecode lower/equal in type hierarchy. 

Determine if the first argument is a subclass of the second argument. 

Check whether or not an object can be iterated over. 

Construct an open mesh from multiple sequences. 

Return the Kaiser window. 

Kronecker product of two arrays. 

Returns the lowest common multiple of 

Returns x1 * 2**x2, elementwise. 

Shift the bits of an integer to the left. 

Return the truth value of (x1 < x2) elementwise. 

Return the truth value of (x1 <= x2) elementwise. 

Perform an indirect stable sort using a sequence of keys. 

Return evenly spaced numbers over a specified interval. 

Load arrays or pickled objects from 

Natural logarithm, elementwise. 

Return the base 10 logarithm of the input array, elementwise. 

Return the natural logarithm of one plus the input array, elementwise. 

Base-2 logarithm of x.

Logarithm of the sum of exponentiations of the inputs. 

Logarithm of the sum of exponentiations of the inputs in base-2.

Compute the truth value of x1 AND x2 elementwise. 

Compute the truth value of NOT x elementwise. 

Compute the truth value of x1 OR x2 elementwise. 

Compute the truth value of x1 XOR x2, elementwise. 

Return numbers spaced evenly on a log scale. 

Return the indices to access (n, n) arrays, given a masking function. 

Matrix product of two arrays. 

Return the maximum of an array or maximum along an axis. 

Elementwise maximum of array elements. 

Compute the arithmetic mean along the specified axis. 

Compute the median along the specified axis. 

Return coordinate matrices from coordinate vectors. 
Return dense multidimensional "meshgrid". 


Return the minimum of an array or minimum along an axis. 

Elementwise minimum of array elements. 

Returns the elementwise remainder of division. 

Return the fractional and integral parts of an array, elementwise. 

Move axes of an array to new positions. 

Return a copy of an array sorted along the first axis. 

Multiply arguments elementwise. 

Replace NaN with zero and infinity with large finite numbers (default 

Return the indices of the maximum values in the specified axis ignoring 

Return the indices of the minimum values in the specified axis ignoring 

Return the cumulative product of array elements over a given axis treating Not a 

Return the cumulative sum of array elements over a given axis treating Not a 

Return the maximum of an array or maximum along an axis, ignoring any 

Compute the arithmetic mean along the specified axis, ignoring NaNs. 

Compute the median along the specified axis, while ignoring NaNs. 

Return minimum of an array or minimum along an axis, ignoring any NaNs. 

Compute the q-th percentile of the data along the specified axis,

Return the product of array elements over a given axis treating Not a 

Compute the q-th quantile of the data along the specified axis,

Compute the standard deviation along the specified axis, while 

Return the sum of array elements over a given axis treating Not a 

Compute the variance along the specified axis, while ignoring NaNs. 
alias of 


Return the number of dimensions of an array. 

Numerical negative, elementwise. 

Return the next floating-point value after x1 towards x2, elementwise.

Return the indices of the elements that are nonzero. 

Return (x1 != x2) elementwise. 

Abstract base class of all numeric scalar types. 
Any Python object. 

Return open multidimensional "meshgrid". 


Return a new array of given shape and type, filled with ones. 

Return an array of ones with the same shape and type as a given array. 

Compute the outer product of two vectors. 

Packs the elements of a binary-valued array into bits in a uint8 array.

Pad an array. 

Compute the q-th percentile of the data along the specified axis.

Evaluate a piecewise-defined function.

Change elements of an array based on conditional and input values. 

Find the coefficients of a polynomial with the given sequence of roots. 

Find the sum of two polynomials. 

Return the derivative of the specified order of a polynomial. 

Returns the quotient and remainder of polynomial division. 

Least squares polynomial fit. 

Return an antiderivative (indefinite integral) of a polynomial. 

Find the product of two polynomials. 

Difference (subtraction) of two polynomials. 

Evaluate a polynomial at specific values. 

Numerical positive, elementwise. 

First array elements raised to powers from second array, elementwise. 

Context manager for setting print options. 

Return the product of array elements over a given axis. 

Return the product of array elements over a given axis. 

Returns the type to which a binary operation should cast its arguments. 

Range of values (maximum - minimum) along an axis.

Replaces specified elements of an array with given values. 

Compute the q-th quantile of the data along the specified axis.
Concatenate slices, scalars and array-like objects along the first axis.


Convert angles from radians to degrees. 

Convert angles from degrees to radians. 

Return a contiguous flattened array. 

Converts a tuple of index arrays into an array of flat 

Return the real part of the complex argument. 

Return the reciprocal of the argument, elementwise. 

Returns the elementwise remainder of division. 

Repeat elements of an array. 

Gives a new shape to an array without changing its data. 

Return a new array with the specified shape. 

Returns the type that results from applying the NumPy 

Shift the bits of an integer to the right. 

Round elements of the array to the nearest integer. 

Roll array elements along a given axis. 

Roll the specified axis backwards, until it lies in a given position. 

Return the roots of a polynomial with coefficients given in p. 

Rotate an array by 90 degrees in the plane specified by axes. 

Evenly round to the given number of decimals. 

Evenly round to the given number of decimals. 

Stack arrays in sequence vertically (row wise). 
A nicer way to build up index tuples for arrays. 


Save an array to a binary file in NumPy 

Save several arrays into a single file in uncompressed 

Find indices where elements should be inserted to maintain order. 

Return an array drawn from elements in choicelist, depending on conditions. 

Set printing options. 

Find the set difference of two arrays. 

Find the set exclusive-or of two arrays.

Return the shape of an array. 

Returns an elementwise indication of the sign of a number. 

Returns elementwise True where signbit is set (less than zero). 
Abstract base class of all signed integer scalar types. 


Trigonometric sine, elementwise. 

Return the normalized sinc function. 
alias of 


Hyperbolic sine, elementwise. 

Return the number of elements along a given axis. 

Test whether any array element along a given axis evaluates to True. 

Return a sorted copy of an array. 

Sort a complex array using the real part first, then the imaginary part. 

Split an array into multiple subarrays as views into ary. 

Return the non-negative square-root of an array, elementwise.

Return the elementwise square of the input. 

Remove axes of length one from a. 

Join a sequence of arrays along a new axis. 

Compute the standard deviation along the specified axis. 

Subtract arguments, elementwise. 

Sum of array elements over a given axis. 

Interchange two axes of an array. 

Take elements from an array along an axis. 

Take values from the input array by matching 1d index and data slices. 

Compute tangent elementwise. 

Compute hyperbolic tangent elementwise. 

Compute tensor dot product along specified axes. 

Construct an array by repeating A the number of times given by reps. 

Return the sum along diagonals of the array. 

Reverse or permute the axes of an array; returns the modified array. 

Integrate along the given axis using the composite trapezoidal rule. 

An array with ones at and below the given diagonal and zeros elsewhere. 

Lower triangle of an array. 

Return the indices for the lower-triangle of an (n, m) array.

Return the indices for the lower-triangle of arr.

Trim the leading and/or trailing zeros from a 1D array or sequence. 

Upper triangle of an array. 

Return the indices for the upper-triangle of an (n, m) array.

Return the indices for the upper-triangle of arr.

Divide arguments elementwise. 

Return the truncated value of the input, elementwise. 
alias of 










Find the union of two arrays. 

Find the unique elements of an array. 

Unpacks elements of a uint8 array into a binary-valued output array.

Converts a flat index or array of flat indices into a tuple 
Abstract base class of all unsigned integer scalar types. 


Unwrap by taking the complement of large deltas with respect to the period. 

Generate a Vandermonde matrix. 

Compute the variance along the specified axis. 

Return the dot product of two vectors. 

Define a vectorized function with broadcasting. 

Split an array into multiple subarrays vertically (rowwise). 

Stack arrays in sequence vertically (row wise). 

Return elements chosen from x or y depending on condition. 

Return a new array of given shape and type, filled with zeros. 

Return an array of zeros with the same shape and type as a given array. 
jax.numpy.fft

Compute the one-dimensional discrete Fourier Transform.

Compute the 2-dimensional discrete Fourier Transform.

Return the Discrete Fourier Transform sample frequencies. 

Compute the N-dimensional discrete Fourier Transform.

Shift the zero-frequency component to the center of the spectrum.

Compute the FFT of a signal that has Hermitian symmetry, i.e., a real 

Compute the one-dimensional inverse discrete Fourier Transform.

Compute the 2-dimensional inverse discrete Fourier Transform.

Compute the N-dimensional inverse discrete Fourier Transform.

The inverse of fftshift. 

Compute the inverse FFT of a signal that has Hermitian symmetry. 

Computes the inverse of rfft. 

Computes the inverse of rfft2. 

Computes the inverse of rfftn. 

Compute the one-dimensional discrete Fourier Transform for real input.

Compute the 2dimensional FFT of a real array. 

Return the Discrete Fourier Transform sample frequencies 

Compute the N-dimensional discrete Fourier Transform for real input.
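A minimal round-trip sketch of a few of the transforms listed above, assuming a recent JAX release; the input signal is arbitrary and the tolerances allow for float32 rounding.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 0.0, -1.0])

X = jnp.fft.fft(x)                # complex spectrum of length 4
x_back = jnp.fft.ifft(X)          # inverse transform; imaginary part is ~0
Xr = jnp.fft.rfft(x)              # real-input FFT keeps n // 2 + 1 = 3 bins
xr_back = jnp.fft.irfft(Xr, n=4)  # pass n to recover the original length
freqs = jnp.fft.fftfreq(4)        # sample frequencies for spacing d=1.0

print(freqs)  # [ 0.    0.25 -0.5  -0.25]
```

The results are expected to match numpy.fft up to floating-point tolerance.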
jax.numpy.linalg

Cholesky decomposition. 

Compute the condition number of a matrix. 

Compute the determinant of an array. 

Compute the eigenvalues and right eigenvectors of a square array. 

Return the eigenvalues and eigenvectors of a complex Hermitian 

Compute the eigenvalues of a general matrix. 

Compute the eigenvalues of a complex Hermitian or real symmetric matrix. 

Compute the (multiplicative) inverse of a matrix. 

Return the least-squares solution to a linear matrix equation.

Raise a square matrix to the (integer) power n. 

Return matrix rank of array using SVD method 

Compute the dot product of two or more arrays in a single function call, 

Matrix or vector norm. 

Compute the (Moore-Penrose) pseudo-inverse of a matrix.

Compute the qr factorization of a matrix. 

Compute the sign and (natural) logarithm of the determinant of an array. 

Solve a linear matrix equation, or system of linear scalar equations. 

Singular Value Decomposition. 

Compute the 'inverse' of an N-dimensional array.

Solve the tensor equation 
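A small sketch of solve, det, cholesky and inv on a hand-picked 2×2 symmetric positive-definite system, chosen so the answers can be checked by hand. It assumes a recent JAX release; the values themselves are illustrative.

```python
import numpy as np
import jax.numpy as jnp

A = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])     # symmetric positive definite
b = jnp.array([9.0, 8.0])

x = jnp.linalg.solve(A, b)      # solves A @ x = b; here x = [2, 3]
det = jnp.linalg.det(A)         # 3*2 - 1*1 = 5
L = jnp.linalg.cholesky(A)      # lower-triangular factor with A = L @ L.T
A_inv = jnp.linalg.inv(A)       # multiplicative inverse of A
```

For a single right-hand side, solve is generally preferable to forming inv(A) @ b explicitly.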
JAX DeviceArray
The JAX DeviceArray is the core array object in JAX: you can think of it as the equivalent of a numpy.ndarray backed by a memory buffer on a single device. Like numpy.ndarray, most users will not need to instantiate DeviceArray objects manually, but rather will create them via jax.numpy functions like array(), arange(), linspace(), and others listed above.
Copying and Serialization
DeviceArray objects are designed to work seamlessly with Python standard library tools where appropriate.
With the built-in copy module, when copy.copy() or copy.deepcopy() encounter a DeviceArray, it is equivalent to calling the copy() method, which will create a copy of the buffer on the same device as the original array. This will work correctly within traced/JIT-compiled code, though copy operations may be elided by the compiler in this context.
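A minimal sketch of the copy behaviour described above. It assumes a recent JAX release, in which the array type is jax.Array rather than DeviceArray; the copy protocol is expected to behave the same way.

```python
import copy
import numpy as np
import jax.numpy as jnp

x = jnp.arange(4.0)

# Both calls should produce a new array equivalent to x.copy();
# the original x is not affected.
shallow = copy.copy(x)
deep = copy.deepcopy(x)
print(shallow, deep)  # [0. 1. 2. 3.] [0. 1. 2. 3.]
```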
When the built-in pickle module encounters a DeviceArray, it will be serialized via a compact bit representation in a similar manner to pickled numpy.ndarray objects. When unpickled, the result will be a new DeviceArray object on the default device. This is because, in general, pickling and unpickling may take place in different runtime environments, and there is no general way to map the device IDs of one runtime to the device IDs of another. If pickle is used in traced/JIT-compiled code, it will result in a ConcretizationTypeError.
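A pickle round trip, sketched under the assumption of a recent JAX release (jax.Array rather than DeviceArray). The restored array should compare equal to the original and keep its dtype, but it is a new object placed on the default device.

```python
import pickle
import numpy as np
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 5)

# Serialize to bytes outside of any jit-traced function, then restore.
payload = pickle.dumps(x)
restored = pickle.loads(payload)
```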
Class Reference
jax.numpy.DeviceArray
alias of jaxlib.xla_extension.DeviceArrayBase
class jaxlib.xla_extension.DeviceArrayBase
class jaxlib.xla_extension.DeviceArray
property T: jax.Array
Reverse or permute the axes of an array; returns the modified array.
LAX-backend implementation of numpy.transpose().
The JAX version of this function may in some cases return a copy rather than a view of the input.
Original docstring below.
For an array a with two axes, transpose(a) gives the matrix transpose.
Refer to numpy.ndarray.transpose for full documentation.
Parameters
a (array_like) – Input array.
axes (tuple or list of ints, optional) – If specified, it must be a tuple or list which contains a permutation of [0,1,..,N-1] where N is the number of axes of a. The i’th axis of the returned array will correspond to the axis numbered axes[i] of the input. If not specified, defaults to range(a.ndim)[::-1], which reverses the order of the axes.
Returns
p – a with its axes permuted. A view is returned whenever possible.
Return type
ndarray
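A short sketch of how T and an explicit axes permutation change shapes, assuming a recent JAX release; the array contents are arbitrary.

```python
import jax.numpy as jnp

m = jnp.arange(6).reshape(2, 3)
print(m.T.shape)  # (3, 2): matrix transpose

t = jnp.zeros((2, 3, 4))
print(t.T.shape)  # (4, 3, 2): by default the axes are reversed

# Explicit permutation: output axis i is input axis axes[i].
print(jnp.transpose(t, axes=(1, 0, 2)).shape)  # (3, 2, 4)
```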
all(axis=None, out=None, keepdims=False, *, where=None)
Test whether all array elements along a given axis evaluate to True.
LAX-backend implementation of numpy.all().
Original docstring below.
Parameters
a (array_like) – Input array or object that can be converted to an array.
axis (None or int or tuple of ints, optional) – Axis or axes along which a logical AND reduction is performed. The default (axis=None) is to perform a logical AND over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.
keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the all method of subclasses of ndarray; however, any non-default value will be. If the subclass’ method does not implement keepdims, any exceptions will be raised.
where (array_like of bool, optional) – Elements to include in checking for all True values. See ~numpy.ufunc.reduce for details.
out (None) –
Returns
all – A new boolean or array is returned unless out is specified, in which case a reference to out is returned.
Return type
ndarray, bool
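A small sketch of all with and without axis and keepdims, assuming a recent JAX release; the boolean input is arbitrary.

```python
import jax.numpy as jnp

x = jnp.array([[True, False],
               [True, True]])

print(jnp.all(x))                          # False: reduces over every element
print(jnp.all(x, axis=0))                  # per column: True, False
print(x.all(axis=1, keepdims=True).shape)  # (2, 1): reduced axis kept with size 1
```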
any(axis=None, out=None, keepdims=False, *, where=None)
Test whether any array element along a given axis evaluates to True.
LAX-backend implementation of numpy.any().
Original docstring below.
Returns single boolean if axis is None.
Parameters
a (array_like) – Input array or object that can be converted to an array.
axis (None or int or tuple of ints, optional) – Axis or axes along which a logical OR reduction is performed. The default (axis=None) is to perform a logical OR over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.
keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the any method of subclasses of ndarray; however, any non-default value will be. If the subclass’ method does not implement keepdims, any exceptions will be raised.
where (array_like of bool, optional) – Elements to include in checking for any True values. See ~numpy.ufunc.reduce for details.
out (None) –
Returns
any – A new boolean or ndarray is returned unless out is specified, in which case a reference to out is returned.
Return type
bool or ndarray
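A sketch of any together with a where mask, assuming a JAX version whose jnp.any accepts the keyword-only where argument shown in the signature above; the input values are arbitrary.

```python
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5])
mask = jnp.array([True, False, True, False])

print(jnp.any(x))              # True: 3 and 5 are nonzero
# where restricts the reduction to the masked elements (both zero here).
print(jnp.any(x, where=mask))  # False
```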
argmax(axis=None, out=None, keepdims=None)
Returns the indices of the maximum values along an axis.
LAX-backend implementation of numpy.argmax().
Original docstring below.
Parameters
a (array_like) – Input array.
axis (int, optional) – By default, the index is into the flattened array, otherwise along the specified axis.
keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.
Returns
index_array – Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, then the size of axis will be 1 with the resulting array having the same shape as a.shape.
Return type
ndarray of ints
argmin(axis=None, out=None, keepdims=None)
Returns the indices of the minimum values along an axis.
LAX-backend implementation of numpy.argmin().
Original docstring below.
Parameters
a (array_like) – Input array.
axis (int, optional) – By default, the index is into the flattened array, otherwise along the specified axis.
keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the array.
Returns
index_array – Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. If keepdims is set to True, then the size of axis will be 1 with the resulting array having the same shape as a.shape.
Return type
ndarray of ints
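A sketch contrasting the flattened default with per-axis argmax/argmin and keepdims, assuming a JAX version whose jnp.argmax accepts keepdims as in the signature above; the values are arbitrary.

```python
import jax.numpy as jnp

x = jnp.array([[1, 9, 3],
               [7, 2, 8]])

print(jnp.argmax(x))          # 1: index into the flattened array
print(jnp.argmax(x, axis=1))  # [1 2]: per-row maxima
print(jnp.argmin(x, axis=0))  # [0 1 0]: per-column minima
print(jnp.argmax(x, axis=1, keepdims=True).shape)  # (2, 1)
```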
argpartition(**kwargs)
Perform an indirect partition along the given axis using the algorithm specified by the kind keyword.
LAX-backend implementation of numpy.argpartition().
This function is not yet implemented by jax.numpy, and will raise NotImplementedError.
Original docstring below.
It returns an array of indices of the same shape as a that index data along the given axis in partitioned order.
New in version 1.8.0.
Parameters
a (array_like) – Array to sort.
kth (int or sequence of ints) –
Element index to partition by. The k-th element will be in its final sorted position and all smaller elements will be moved before it and all larger elements behind it. The order of all elements in the partitions is undefined. If provided with a sequence of k-th it will partition all of them into their sorted position at once.
Deprecated since version 1.22.0: Passing booleans as index is deprecated.
axis (int or None, optional) – Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used.
kind ({'introselect'}, optional) – Selection algorithm. Default is ‘introselect’.
order (str or list of str, optional) – When a is an array with fields defined, this argument specifies which fields to compare first, second, etc. A single field can be specified as a string, and not all fields need be specified, but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties.
Returns
index_array – Array of indices that partition a along the specified axis. If a is one-dimensional, a[index_array] yields a partitioned a. More generally, np.take_along_axis(a, index_array, axis) always yields the partitioned a, irrespective of dimensionality.
Return type
ndarray, int
 argsort(axis= 1, kind='stable', order=None)#
Returns the indices that would sort an array.
LAXbackend implementation of
numpy.argsort()
.Only
kind='stable'
is supported. Otherkind
values will produce a warning and be treated as if they were'stable'
.Original docstring below.
Perform an indirect sort along the given axis using the algorithm specified by the kind keyword. It returns an array of indices of the same shape as a that index data along the given axis in sorted order.
 Parameters
a (array_like) – Array to sort.
axis (int or None, optional) – Axis along which to sort. The default is 1 (the last axis). If None, the flattened array is used.
kind ({'quicksort', 'mergesort', 'heapsort', 'stable'}, optional) –
Sorting algorithm. The default is ‘quicksort’. Note that both ‘stable’ and ‘mergesort’ use timsort under the covers and, in general, the actual implementation will vary with data type. The ‘mergesort’ option is retained for backwards compatibility.
Changed in version 1.15.0.: The ‘stable’ option was added.
order (str or list of str, optional) – When a is an array with fields defined, this argument specifies which fields to compare first, second, etc. A single field can be specified as a string, and not all fields need be specified, but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties.
 Returns
index_array – Array of indices that sort a along the specified axis. If a is one-dimensional, a[index_array] yields a sorted a. More generally, np.take_along_axis(a, index_array, axis=axis) always yields the sorted a, irrespective of dimensionality.
 Return type
ndarray, int
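The following is a short sketch of the take_along_axis relationship stated above, using arbitrary example values. The last line illustrates the stable ordering that the JAX note guarantees: tied values keep their original relative order.

```python
import jax.numpy as jnp

x = jnp.array([[3, 1, 2],
               [0, 5, 4]])

# Sort indices along the last axis (axis=-1 is the default).
idx = x.argsort()

# take_along_axis applies those indices to recover the sorted array.
sorted_x = jnp.take_along_axis(x, idx, axis=-1)

# Ties keep their original relative order, since only a stable sort is used.
ties = jnp.array([2, 1, 2, 1]).argsort()
```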
 astype(dtype)
Copy the array and cast to a specified dtype.
This is implemented via jax.lax.convert_element_type(), which may have slightly different behavior than numpy.ndarray.astype() in some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent.
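The following is a minimal sketch of astype. Because the text above warns that float-to-int details are implementation dependent, the example deliberately sticks to values that convert exactly, such as whole-number floats. Fractional or out-of-range inputs are where NumPy and JAX results may differ.

```python
import jax.numpy as jnp

x = jnp.arange(4)              # integer array
y = x.astype(jnp.float32)      # new float32 array; x keeps its dtype

# Whole-number floats convert to the same integers in NumPy and JAX.
z = jnp.array([2.0, 3.0]).astype(jnp.int32)
```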
 property at
Helper property for index update functionality.
The at property provides a functionally pure equivalent of in-place array modifications. In particular:
Alternate syntax                  Equivalent in-place expression
x = x.at[idx].set(y)              x[idx] = y
x = x.at[idx].add(y)              x[idx] += y
x = x.at[idx].multiply(y)         x[idx] *= y
x = x.at[idx].divide(y)           x[idx] /= y
x = x.at[idx].power(y)            x[idx] **= y
x = x.at[idx].min(y)              x[idx] = minimum(x[idx], y)
x = x.at[idx].max(y)              x[idx] = maximum(x[idx], y)
x = x.at[idx].apply(ufunc)        ufunc.at(x, idx)
x = x.at[idx].get()               x = x[idx]
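The following minimal sketch exercises a few rows of this table with arbitrary example values. Each call returns a new array and leaves x untouched. The function zero_first is a hypothetical name used only for illustration. The in-place application under jit is the documented compiler behavior and is not something the code can observe.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(5)

y = x.at[1:3].set(7.0)                 # slice update, returns a new array
z = x.at[jnp.array([0, 4])].max(3.0)   # elementwise max at two positions
# x itself is unchanged: still all zeros.

@jax.jit
def zero_first(v):
    # Under jit, XLA can apply this update in place.
    return v.at[0].set(0.0)

w = zero_first(jnp.ones(3))
```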
None of the x.at expressions modify the original x; instead they return a modified copy of x. However, inside a jit() compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.
Unlike NumPy in-place operations such as x[idx] += y, if multiple indices refer to the same location, all updates will be applied (NumPy would only apply the last update, rather than applying all updates). The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).
By default, JAX assumes that all indices are in-bounds. There is experimental support for giving more precise semantics to out-of-bounds indexed accesses, via the mode parameter (see below).
 Parameters
mode (str) –
Specify out-of-bound indexing mode. Options are:
"promise_in_bounds": (default) The user promises that indices are in bounds. No additional checking will be performed. In practice, this means that out-of-bounds indices in get() will be clipped, and out-of-bounds indices in set(), add(), etc. will be dropped.
"clip": clamp out-of-bounds indices into valid range.
"drop": ignore out-of-bound indices.
"fill": alias for "drop". For get(), the optional fill_value argument specifies the value that will be returned.
See jax.lax.GatherScatterMode for more details.
indices_are_sorted (bool) – If True, the implementation will assume that the indices passed to at[] are sorted in ascending order, which can lead to more efficient execution on some backends.
unique_indices (bool) – If True, the implementation will assume that the indices passed to at[] are unique, which can result in more efficient execution on some backends.
fill_value (Any) – Only applies to the get() method: the fill value to return for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans.
Examples
>>> x = jnp.arange(5.0)
>>> x
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[2].add(10)
Array([ 0.,  1., 12.,  3.,  4.], dtype=float32)
>>> x.at[10].add(10)  # out-of-bounds indices are ignored
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[20].add(10, mode='clip')
Array([ 0.,  1.,  2.,  3., 14.], dtype=float32)
>>> x.at[2].get()
Array(2., dtype=float32)
>>> x.at[20].get()  # out-of-bounds indices clipped
Array(4., dtype=float32)
>>> x.at[20].get(mode='fill')  # out-of-bounds indices filled with NaN
Array(nan, dtype=float32)
>>> x.at[20].get(mode='fill', fill_value=1)  # custom fill value
Array(1., dtype=float32)
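The following sketch contrasts the duplicate-index behavior described above with NumPy. Index 0 is repeated on purpose. NumPy's buffered x[idx] += 1 applies that index once, while .at[].add() applies every update, which matches NumPy's unbuffered np.add.at.

```python
import numpy as np
import jax.numpy as jnp

idx = np.array([0, 0, 2])      # index 0 appears twice

x_np = np.zeros(3)
x_np[idx] += 1                 # NumPy: buffered, index 0 incremented once

x_np_at = np.zeros(3)
np.add.at(x_np_at, idx, 1)     # NumPy's unbuffered variant applies every update

x_jax = jnp.zeros(3).at[idx].add(1)   # JAX applies every update
```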
 block_host_until_ready()
(self: xla::PyBuffer::pyobject) -> Status
 block_until_ready()
(self: xla::PyBuffer::pyobject) -> StatusOr[xla::PyBuffer::pyobject]
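The signatures above describe the underlying buffer type. This sketch assumes the current jax.Array behavior, where block_until_ready() waits for the computation and then returns the array itself. A typical use is timing: JAX dispatches work asynchronously, so without blocking, a timer can stop before the result exists. The matrix size is an arbitrary choice.

```python
import time
import jax.numpy as jnp

x = jnp.ones((256, 256))

start = time.perf_counter()
# Without block_until_ready, the timer could stop before the
# asynchronously dispatched matmul has actually finished.
y = (x @ x).block_until_ready()
elapsed = time.perf_counter() - start
```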
 broadcast(sizes)
Broadcasts an array, adding new leading dimensions.
 Parameters
 Return type
Array
 Returns
An array containing the result.
See also
jax.lax.broadcast_in_dim : add new dimensions at any location in the array shape.
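The following is a minimal sketch that calls jax.lax.broadcast directly, the function behind this method. The sizes (2, 4) are arbitrary. The new dimensions are prepended to the existing shape.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3])

# Add two new leading dimensions of sizes 2 and 4: (3,) -> (2, 4, 3).
y = lax.broadcast(x, (2, 4))
```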
 broadcast_in_dim(shape, broadcast_dimensions)
Wraps XLA’s BroadcastInDim operator.
 Parameters
 Return type
Array
 Returns
An array containing the result.
See also
jax.lax.broadcast : simpler interface to add new leading dimensions.
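The following sketch calls jax.lax.broadcast_in_dim directly with arbitrary shapes. broadcast_dimensions says which output axis each operand axis maps to. The remaining output axes are filled by repetition.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3])

# Map operand dim 0 to output dim 0, giving shape (3, 2) with out[i, j] == x[i].
col = lax.broadcast_in_dim(x, shape=(3, 2), broadcast_dimensions=(0,))

# Map operand dim 0 to output dim 1 instead, giving out[i, j] == x[j].
row = lax.broadcast_in_dim(x, shape=(2, 3), broadcast_dimensions=(1,))
```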
 choose(choices, out=None, mode='raise')
Construct an array from an index array and a list of arrays to choose from.
LAX-backend implementation of numpy.choose().
Original docstring below.
First of all, if confused or uncertain, definitely look at the Examples; in its full generality, this function is less simple than it might seem from the following code description (below ndi = numpy.lib.index_tricks):
np.choose(a,c) == np.array([c[a[I]][I] for I in ndi.ndindex(a.shape)]).
But this omits some subtleties. Here is a fully general summary:
Given an “index” array (a) of integers and a sequence of n arrays (choices), a and each choice array are first broadcast, as necessary, to arrays of a common shape; calling these Ba and Bchoices[i], i = 0,…,n-1 we have that, necessarily, Ba.shape == Bchoices[i].shape for each i. Then, a new array with shape Ba.shape is created as follows:
if mode='raise' (the default), then, first of all, each element of a (and thus Ba) must be in the range [0, n-1]; now, suppose that i (in that range) is the value at the (j0, j1, ..., jm) position in Ba; then the value at the same position in the new array is the value in Bchoices[i] at that same position;
if mode='wrap', values in a (and thus Ba) may be any (signed) integer; modular arithmetic is used to map integers outside the range [0, n-1] back into that range; and then the new array is constructed as above;
if mode='clip', values in a (and thus Ba) may be any (signed) integer; negative integers are mapped to 0; values greater than n-1 are mapped to n-1; and then the new array is constructed as above.
 Parameters
a (int array) – This array must contain integers in [0, n-1], where n is the number of choices, unless mode=wrap or mode=clip, in which cases any integers are permissible.
choices (sequence of arrays) – Choice arrays. a and all of the choices must be broadcastable to the same shape. If choices is itself an array (not recommended), then its outermost dimension (i.e., the one corresponding to choices.shape[0]) is taken as defining the “sequence”.
mode ({'raise' (default), 'wrap', 'clip'}, optional) –
Specifies how indices outside [0, n-1] will be treated:
‘raise’ : an exception is raised
‘wrap’ : value becomes value mod n
‘clip’ : values < 0 are mapped to 0, values > n-1 are mapped to n-1
out (None) –
 Returns
merged_array – The merged result.
 Return type
array
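The following is a minimal sketch of the summary above, using jnp.choose with a list of choice arrays, as the parameter text recommends. The values are arbitrary. The second half feeds the same out-of-range indices to the two non-raising modes so that their index mapping can be compared.

```python
import jax.numpy as jnp

choices = [jnp.array([0, 1, 2, 3]),
           jnp.array([10, 11, 12, 13]),
           jnp.array([20, 21, 22, 23])]

# Each entry of `a` picks which choice array to read at that position.
a = jnp.array([2, 0, 1, 0])
picked = jnp.choose(a, choices)

# Out-of-range indices under the two non-raising modes (n = 3 choices).
b = jnp.array([-1, 5, 1, 2])
clipped = jnp.choose(b, choices, mode='clip')  # b is mapped to [0, 2, 1, 2]
wrapped = jnp.choose(b, choices, mode='wrap')  # b is mapped to [2, 2, 1, 2]
```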
 clone()
(self: xla::PyBuffer::pyobject) -> xla::PyBuffer::pyobject
 conj()
Return the complex conjugate, element-wise.
LAX-backend implementation of numpy.conjugate().
Original docstring below.
The complex conjugate of a complex number is obtained by changing the sign of its imaginary part.
 Parameters
x (array_like) – Input value.
 Returns
y – The complex conjugate of x, with the same dtype as x. This is a scalar if x is a scalar.
 Return type
ndarray
 conjugate()
Return the complex conjugate, element-wise.
LAX-backend implementation of numpy.conjugate().
Original docstring below.
The complex conjugate of a complex number is obtained by changing the sign of its imaginary part.
 Parameters
x (array_like) – Input value.
 Returns
y – The complex conjugate of x, with the same dtype as x. This is a scalar if x is a scalar.
 Return type
ndarray
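The following short sketch covers both conj() and conjugate(), which perform the same operation. The input values are arbitrary. For complex input, the imaginary part changes sign and the dtype is kept. For real input, the values pass through unchanged.

```python
import jax.numpy as jnp

z = jnp.array([1 + 2j, 3 - 4j])
zc = z.conj()                 # imaginary parts change sign
same = z.conjugate()          # conjugate() is the same operation

r = jnp.array([1.0, -2.0])
rc = r.conj()                 # real input: values unchanged
```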
 copy(order=None)
Return an array copy of the given object.
LAX-backend implementation of numpy.copy().
This function will create arrays on JAX’s default device. For control of the device placement of data, see jax.device_put(). More information is available in the JAX FAQ at Controlling data and computation placement on devices (full FAQ at https://jax.readthedocs.io/en/latest/faq.html).
Original docstring below.
 Parameters
a (array_like) – Input data.
order ({'C', 'F', 'A', 'K'}, optional) – Controls the memory layout of the copy. ‘C’ means C-order, ‘F’ means F-order, ‘A’ means ‘F’ if a is Fortran contiguous, ‘C’ otherwise. ‘K’ means match the layout of a as closely as possible. (Note that this function and ndarray.copy() are very similar, but have different default values for their order= arguments.)
 Returns
arr – Array interpretation of a.
 Return type
ndarray
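The following minimal sketch shows that copy() yields a distinct array with equal contents. Explicit placement goes through jax.device_put, as the note above says. Picking jax.devices()[0] is an arbitrary choice that works on any backend.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(3)
y = x.copy()                   # new array object with equal contents

# Explicit device placement uses jax.device_put rather than copy().
dev = jax.devices()[0]
z = jax.device_put(x, dev)
```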
 copy_to_device()
(self: xla::PyBuffer::pyobject, arg0: jaxlib.xla_extension.Device) -> StatusOr[object]
 copy_to_host_async()
(self: xla::PyBuffer::pyobject) -> Status
 copy_to_remote_device()
(self: xla::PyBuffer::pyobject, arg0: bytes) -> Tuple[Status, bool]
 cumprod(axis=None, dtype=None, out=None)
Return the cumulative product of elements along a given axis.
LAX-backend implementation of numpy.cumprod().
Original docstring below.
 Parameters
a (array_like) – Input array.
axis (int, optional) – Axis along which the cumulative product is computed. By default the input is flattened.
dtype (dtype, optional) – Type of the returned array, as well as of the accumulator in which the elements are multiplied. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used instead.
out (None) –
 Returns
cumprod – A new array holding the result is returned unless out is specified, in which case a reference to out is returned.
 Return type
ndarray
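The following short sketch of the axis parameter uses an arbitrary 2x2 input. With the default axis=None the input is flattened first. An explicit axis keeps the original shape.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2],
               [3, 4]])

flat = x.cumprod()          # axis=None flattens first: [1, 2, 6, 24]
down = x.cumprod(axis=0)    # down each column: [[1, 2], [3, 8]]
across = x.cumprod(axis=1)  # along each row: [[1, 2], [3, 12]]
```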
 cumsum(axis=None, dtype=None, out=None)
Return the cumulative sum of the elements along a given axis.
LAX-backend implementation of numpy.cumsum().
Original docstring below.
 Parameters
a (array_like) – Input array.
axis (int, optional) – Axis along which the cumulative sum is computed. The default (None) is to compute the cumsum over the flattened array.
dtype (dtype, optional) – Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used.
out (None) –
 Returns
cumsum_along_axis – A new array holding the result is returned unless out is specified, in which case a reference to out is returned. The result has the same size as a, and the same shape as a if axis is not None or a is a 1-d array.
 Return type
ndarray.
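The following sketch covers the default flattening, an explicit axis, and an explicit dtype, which sets the type of both the accumulator and the result. The input values are arbitrary.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2],
               [3, 4]])

flat = x.cumsum()                          # flattened: [1, 3, 6, 10]
rows = x.cumsum(axis=1)                    # within each row: [[1, 3], [3, 7]]
f = x.cumsum(axis=0, dtype=jnp.float32)    # [[1., 2.], [4., 6.]] as float32
```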
 delete()
(self: xla::PyBuffer::pyobject) -> None
 device()
(self: xla::PyBuffer::pyobject) -> jaxlib.xla_extension.Device
 diagonal(offset=0, axis1=0, axis2=1)
Return specified diagonals.
LAX-backend implementation of numpy.diagonal().
The JAX version of this function may in some cases return a copy rather than a view of the input.
Original docstring below.
If a is 2D, returns the diagonal of a with the given offset, i.e., the collection of elements of the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and axis2 are used to determine the 2D subarray whose diagonal is returned. The shape of the resulting array can be determined by removing axis1 and axis2 and appending an index to the right equal to the size of the resulting diagonals.
In versions of NumPy prior to 1.7, this function always returned a new, independent array containing a copy of the values in the diagonal.
In NumPy 1.7 and 1.8, it continues to return a copy of the diagonal, but depending on this fact is deprecated. Writing to the resulting array continues to work as it used to, but a FutureWarning is issued.
Starting in NumPy 1.9 it returns a readonly view on the original array. Attempting to write to the resulting array will produce an error.
In some future release, it will return a read/write view and writing to the returned array will alter your original array. The returned array will have the same type as the input array.
If you don’t write to the array returned by this function, then you can just ignore all of the above.
If you depend on the current behavior, then we suggest copying the returned array explicitly, i.e., use np.diagonal(a).copy() instead of just np.diagonal(a). This will work with both past and future versions of NumPy.
 Parameters
a (array_like) – Array from which the diagonals are taken.
offset (int, optional) – Offset of the diagonal from the main diagonal. Can be positive or negative. Defaults to main diagonal (0).
axis1 (int, optional) – Axis to be used as the first axis of the 2D subarrays from which the diagonals should be taken. Defaults to first axis (0).
axis2 (int, optional) – Axis to be used as the second axis of the 2D subarrays from which the diagonals should be taken. Defaults to second axis (1).
 Returns
array_of_diagonals – If a is 2D, then a 1D array containing the diagonal and of the same type as a is returned unless a is a matrix, in which case a 1D array rather than a (2D) matrix is returned in order to maintain backward compatibility.
If a.ndim > 2, then the dimensions specified by axis1 and axis2 are removed, and a new axis inserted at the end corresponding to the diagonal.
 Return type
ndarray
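The following sketch illustrates offset and the N-D rule described above, with arbitrary arange inputs. For the 3-D case, axis1 and axis2 select the 2-D planes, and the diagonal of each plane becomes the last axis of the result.

```python
import jax.numpy as jnp

a = jnp.arange(9).reshape(3, 3)
# [[0, 1, 2],
#  [3, 4, 5],
#  [6, 7, 8]]

main = a.diagonal()            # [0, 4, 8]
upper = a.diagonal(offset=1)   # [1, 5]
lower = a.diagonal(offset=-1)  # [3, 7]

# axis1/axis2 pick the 2-D planes; the diagonal becomes the last axis.
b = jnp.arange(8).reshape(2, 2, 2)
planes = b.diagonal(axis1=1, axis2=2)   # shape (2, 2): [[0, 3], [4, 7]]
```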
 dot(b, *, precision=None)
Dot product of two arrays.
LAX-backend implementation of numpy.dot().
In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two Precision enums indicating separate precision for each argument.
Original docstring below.
Specifically:
If both a and b are 1D arrays, it is inner product of vectors (without complex conjugation).
If both a and b are 2D arrays, it is matrix multiplication, but using matmul() or a @ b is preferred.
If either a or b is 0D (scalar), it is equivalent to multiply() and using numpy.multiply(a, b) or a * b is preferred.
If a is an N-D array and b is a 1D array, it is a sum product over the last axis of a and b.
If a is an N-D array and b is an M-D array (where M>=2), it is a sum product over the last axis of a and the second-to-last axis of b:
dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
 Parameters
a (array_like) – First argument.
b (array_like) – Second argument.
 Returns
output – Returns the dot product of a and b. If a and b are both scalars or both 1D arrays then a scalar is returned; otherwise an array is returned. If out is given, then it is returned.
 Return type
ndarray
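The following minimal sketch covers two of the cases above with arbitrary small matrices. It passes the JAX-specific precision argument as jax.lax.Precision.HIGHEST. On backends where precision has no effect, the result is the same.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1., 2.],
               [3., 4.]])
b = jnp.array([[5., 6.],
               [7., 8.]])

# 2-D with 2-D: matrix product, here requesting the highest available precision.
c = a.dot(b, precision=jax.lax.Precision.HIGHEST)   # [[19., 22.], [43., 50.]]

# 2-D with 1-D: sum product over the last axis of `a`.
v = jnp.array([1., 2.])
av = a.dot(v)                                       # [5., 11.]
```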
 flatten(order='C')
Return a contiguous flattened array.
LAX-backend implementation of numpy.ravel().
The JAX version of this function may in some cases return a copy rather than a view of the input.
Original docstring below.
A 1D array, containing the elements of the input, is returned. A copy is made only if needed.
As of NumPy 1.10, the returned array will have the same type as the input array. (for example, a masked array will be returned for a masked array input)
 Parameters
a (array_like) – Input array. The elements in a are read in the order specified by order, and packed as a 1D array.
order ({'C','F', 'A', 'K'}, optional) – The elements of a are read using this index order. ‘C’ means to index the elements in row-major, C-style order, with the last axis index changing fastest, back to the first axis index changing slowest. ‘F’ means to index the elements in column-major, Fortran-style order, with the first index changing fastest, and the last index changing slowest. Note that the ‘C’ and ‘F’ options take no account of the memory layout of the underlying array, and only refer to the order of axis indexing. ‘A’ means to read the elements in Fortran-like index order if a is Fortran contiguous in memory, C-like order otherwise. ‘K’ means to read the elements in the order they occur in memory, except for reversing the data when strides are negative. By default, ‘C’ index order is used.
 Returns
y – y is an array of the same subtype as a, with shape (a.size,). Note that matrices are special cased for backward compatibility, if a is a matrix, then y is a 1D ndarray.
 Return type
array_like
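The following short sketch compares the two index orders on an arbitrary 2x2 array. It uses only 'C' and 'F'. The JAX version may not implement 'A' or 'K', which is an assumption worth checking against your JAX release.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4]])

c_order = a.flatten()            # row-major: [1, 2, 3, 4]
f_order = a.flatten(order='F')   # column-major: [1, 3, 2, 4]
```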
 property imag: jax.Array
Return the imaginary part of the complex argument.
LAX-backend implementation of numpy.imag().
Original docstring below.
 Parameters
val (array_like) – Input array.
 Returns
out – The imaginary component of the complex argument. If val is real, the type of val is used for the output. If val has complex elements, the returned type is float.
 Return type
ndarray or scalar
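The following sketch covers the imag property together with its counterpart, real, using arbitrary values. Complex input yields floating-point components. A real input's imag is all zeros in the input's own dtype.

```python
import jax.numpy as jnp

z = jnp.array([1 + 2j, -3.5 + 0j])
re = z.real     # [1., -3.5], floating-point dtype
im = z.imag     # [2., 0.]

x = jnp.array([1.0, 2.0])
x_im = x.imag   # real input: zeros of the same dtype
```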
 is_deleted()
(self: xla::PyBuffer::pyobject) -> bool
 is_known_ready()
(self: xla::PyBuffer::pyobject) -> StatusOr[bool]
 is_ready()
(self: xla::PyBuffer::pyobject) -> StatusOr[bool]
 max(axis=None, out=None, keepdims=False, initial=None, where=None)
Return the maximum of an array or maximum along an axis.
LAX-backend implementation of numpy.amax().
Original docstring below.
 Parameters
a (array_like) – Input data.
axis (None or int or tuple of ints, optional) – Axis or axes along which to operate. By default, flattened input is used.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the amax method of subclasses of ndarray, however any nondefault value will be. If the subclass’ method does not implement keepdims any exceptions will be raised.
initial (scalar, optional) – The minimum value of an output element. Must be present to allow computation on empty slice. See ~numpy.ufunc.reduce for details.
where (array_like of bool, optional) – Elements to compare for the maximum. See ~numpy.ufunc.reduce for details.
out (None) –
 Returns
amax – Maximum of a. If axis is None, the result is a scalar value. If axis is given, the result is an array of dimension a.ndim - 1.
 Return type
ndarray or scalar
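The following sketch exercises axis, keepdims, and where together with initial, using arbitrary values. The example supplies initial alongside where because max has no identity element for fully masked slices, and JAX expects an initial value in that situation (stated here as an assumption about current JAX behavior).

```python
import jax.numpy as jnp

x = jnp.array([[1, 5, 3],
               [4, 2, 6]])

overall = x.max()                          # 6
per_col = x.max(axis=0)                    # [4, 5, 6]
kept = x.max(axis=1, keepdims=True)        # [[5], [6]], shape (2, 1)

# Only compare elements below 5; initial=0 covers any fully masked row.
masked = x.max(axis=1, where=x < 5, initial=0)   # [3, 4]
```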
 mean(axis=None, dtype=None, out=None, keepdims=False, *, where=None)
Compute the arithmetic mean along the specified axis.
LAX-backend implementation of numpy.mean().
Original docstring below.
Returns the average of the array elements. The average is taken over the flattened array by default, otherwise over the specified axis. float64 intermediate and return values are used for integer inputs.
 Parameters
a (array_like) – Array containing numbers whose mean is desired. If a is not an array, a conversion is attempted.
axis (None or int or tuple of ints, optional) – Axis or axes along which the means are computed. The default is to compute the mean of the flattened array.
dtype (datatype, optional) – Type to use in computing the mean. For integer inputs, the default is float64; for floating point inputs, it is the same as the input dtype.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the mean method of subclasses of ndarray, however any nondefault value will be. If the subclass’ method does not implement keepdims any exceptions will be raised.
where (array_like of bool, optional) – Elements to include in the mean. See ~numpy.ufunc.reduce for details.
out (None) –
 Returns
m – If out=None, returns a new array containing the mean values, otherwise a reference to the output array is returned.
 Return type
ndarray, see dtype parameter above
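The following is a short sketch with arbitrary values. The last line concerns integer input. The NumPy text above mentions float64, but under JAX's default 32-bit mode the result is typically float32, so the check only requires a floating-point dtype.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])

overall = x.mean()                       # 3.5
cols = x.mean(axis=0)                    # [2.5, 3.5, 4.5]
masked = x.mean(axis=1, where=x > 1.5)   # row 0 averages [2., 3.]: [2.5, 5.]

# Integer input gives a floating-point result (float32 unless 64-bit mode is on).
int_mean = jnp.arange(4).mean()          # 1.5
```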
 min(axis=None, out=None, keepdims=False, initial=None, where=None)
Return the minimum of an array or minimum along an axis.
LAX-backend implementation of numpy.amin().
Original docstring below.
 Parameters
a (array_like) – Input data.
axis (None or int or tuple of ints, optional) – Axis or axes along which to operate. By default, flattened input is used.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the amin method of subclasses of ndarray, however any nondefault value will be. If the subclass’ method does not implement keepdims any exceptions will be raised.
initial (scalar, optional) – The maximum value of an output element. Must be present to allow computation on empty slice. See ~numpy.ufunc.reduce for details.
where (array_like of bool, optional) – Elements to compare for the minimum. See ~numpy.ufunc.reduce for details.
out (None) –
 Returns
amin – Minimum of a. If axis is None, the result is a scalar value. If axis is given, the result is an array of dimension a.ndim - 1.
 Return type
ndarray or scalar
 nonzero(*, size=None, fill_value=None)
Return the indices of the elements that are nonzero.
LAX-backend implementation of numpy.nonzero().
Because the size of the output of nonzero is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument which must be specified statically for jnp.nonzero to be used within some of JAX’s transformations.
Original docstring below.
Returns a tuple of arrays, one for each dimension of a, containing the indices of the nonzero elements in that dimension. The values in a are always tested and returned in row-major, C-style order.
To group the indices by element, rather than dimension, use argwhere, which returns a row for each nonzero element.
Note
When called on a zero-d array or scalar, nonzero(a) is treated as nonzero(atleast_1d(a)).
Deprecated since version 1.17.0: Use atleast_1d explicitly if this behavior is deliberate.
 Parameters
a (array_like) – Input array.
size (int, optional) – If specified, the indices of the first size True elements will be returned. If there are fewer nonzero elements than size indicates, the return value will be padded with fill_value.
fill_value (array_like, optional) – When size is specified and there are fewer than the indicated number of elements, the remaining elements will be filled with fill_value, which defaults to zero.
 Returns
tuple_of_arrays – Indices of elements that are nonzero.
 Return type
tuple
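The following sketch illustrates the static size argument described above, with arbitrary input values. Fixing size makes the output shape known at trace time, so the call works inside jax.jit. The function first_two_nonzero is a hypothetical name used for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5, 7])

eager = x.nonzero()            # (Array([1, 3, 4]),); shape depends on the data

@jax.jit
def first_two_nonzero(v):
    # size must be a static Python int so the output shape is fixed.
    return v.nonzero(size=2)

first_two = first_two_nonzero(x)              # (Array([1, 3]),)
padded = x.nonzero(size=5, fill_value=-1)     # (Array([1, 3, 4, -1, -1]),)
```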
 on_device_size_in_bytes()
(self: xla::PyBuffer::pyobject) -> StatusOr[int]
 platform()
(self: xla::PyBuffer::pyobject) -> str
 prod(axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)
Return the product of array elements over a given axis.
LAX-backend implementation of numpy.prod().
Original docstring below.
 Parameters
a (array_like) – Input data.
axis (None or int or tuple of ints, optional) – Axis or axes along which a product is performed. The default, axis=None, will calculate the product of all the elements in the input array. If axis is negative it counts from the last to the first axis.
dtype (dtype, optional) – The type of the returned array, as well as of the accumulator in which the elements are multiplied. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the prod method of subclasses of ndarray, however any nondefault value will be. If the subclass’ method does not implement keepdims any exceptions will be raised.
initial (scalar, optional) – The starting value for this product. See ~numpy.ufunc.reduce for details.
where (array_like of bool, optional) – Elements to include in the product. See ~numpy.ufunc.reduce for details.
promote_integers (bool, default=True) – If True, then integer inputs will be promoted to the widest available integer dtype, following NumPy’s behavior. If False, the result will have the same dtype as the input. promote_integers is ignored if dtype is specified.
out (None) –
 Returns
product_along_axis – An array shaped as a but with the specified axis removed. Returns a reference to out if specified.
 Return type
ndarray, see dtype parameter above.
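The following minimal sketch shows promote_integers on an int8 input with arbitrary values. Under JAX's default 32-bit mode, the promoted type is expected to be int32, so the check only requires that it differs from int8. An explicit dtype takes precedence over promote_integers.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4], dtype=jnp.int8)

wide = x.prod()                           # promoted to the default integer type
narrow = x.prod(promote_integers=False)   # stays int8, so large products can overflow
as_float = x.prod(dtype=jnp.float32)      # explicit dtype: promote_integers ignored
```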
 ptp(axis=None, out=None, keepdims=False)
Range of values (maximum - minimum) along an axis.
LAX-backend implementation of numpy.ptp().
Original docstring below.
The name of the function comes from the acronym for ‘peak to peak’.
Warning
ptp preserves the data type of the array. This means the return value for an input of signed integers with n bits (e.g. np.int8, np.int16, etc) is also a signed integer with n bits. In that case, peak-to-peak values greater than 2**(n-1)-1 will be returned as negative values. An example with a workaround is shown below.
 Parameters
a (array_like) – Input values.
axis (None or int or tuple of ints, optional) – Axis along which to find the peaks. By default, flatten the array. axis may be negative, in which case it counts from the last to the first axis.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the ptp method of subclasses of ndarray, however any nondefault value will be. If the subclass’ method does not implement keepdims any exceptions will be raised.
out (None) –
 Returns
ptp – A new array holding the result, unless out was specified, in which case a reference to out is returned.
 Return type
ndarray
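The following is a short sketch using the function form jnp.ptp, with arbitrary values. The second half illustrates the signed-integer warning. The true range of the int8 pair is 200, which exceeds 2**7-1. Widening the dtype first, here to int16, keeps the result exact.

```python
import jax.numpy as jnp

x = jnp.array([[4., 9., 2.],
               [3., 1., 7.]])
per_row = jnp.ptp(x, axis=1)            # [9 - 2, 7 - 1] = [7., 6.]

# 100 - (-100) = 200 does not fit in int8 ...
small = jnp.array([-100, 100], dtype=jnp.int8)
# ... so widen the dtype before taking the range.
wide = jnp.ptp(small.astype(jnp.int16))  # 200
```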
 ravel(order='C')
Return a contiguous flattened array.
LAX-backend implementation of numpy.ravel().
The JAX version of this function may in some cases return a copy rather than a view of the input.
Original docstring below.
A 1D array, containing the elements of the input, is returned. A copy is made only if needed.
As of NumPy 1.10, the returned array will have the same type as the input array. (for example, a masked array will be returned for a masked array input)
 Parameters
a (array_like) – Input array. The elements in a are read in the order specified by order, and packed as a 1D array.
order ({'C','F', 'A', 'K'}, optional) – The elements of a are read using this index order. ‘C’ means to index the elements in row-major, C-style order, with the last axis index changing fastest, back to the first axis index changing slowest. ‘F’ means to index the elements in column-major, Fortran-style order, with the first index changing fastest, and the last index changing slowest. Note that the ‘C’ and ‘F’ options take no account of the memory layout of the underlying array, and only refer to the order of axis indexing. ‘A’ means to read the elements in Fortran-like index order if a is Fortran contiguous in memory, C-like order otherwise. ‘K’ means to read the elements in the order they occur in memory, except for reversing the data when strides are negative. By default, ‘C’ index order is used.
 Returns
y – y is an array of the same subtype as a, with shape (a.size,). Note that matrices are special cased for backward compatibility, if a is a matrix, then y is a 1D ndarray.
 Return type
array_like
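The following sketch illustrates what "order of axis indexing" means for an arbitrary 2x3 array. For a 2-D array, reading in column-major ('F') order visits the same elements as reading the transpose in row-major order.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

c = x.ravel()              # [0, 1, 2, 3, 4, 5]
f = x.ravel(order='F')     # [0, 3, 1, 4, 2, 5]
same = x.T.ravel()         # row-major order of the transpose matches f
```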
 property real: jax.Array
Return the real part of the complex argument.
LAX-backend implementation of numpy.real().
Original docstring below.
 Parameters
val (array_like) – Input array.
 Returns
out – The real component of the complex argument. If val is real, the type of val is used for the output. If val has complex elements, the returned type is float.
 Return type
ndarray or scalar
 repeat(repeats, axis=None, *, total_repeat_length=None)
Repeat elements of an array.
LAX-backend implementation of numpy.repeat().
JAX adds the optional total_repeat_length parameter, which specifies the total number of repeats and defaults to sum(repeats). It must be specified for repeat to be compilable. If sum(repeats) is larger than the specified total_repeat_length, the remaining values will be discarded. If sum(repeats) is smaller than the specified target length, the final value will be repeated.
Original docstring below.
 Parameters
a (array_like) – Input array.
repeats (int or array of ints) – The number of repetitions for each element. repeats is broadcasted to fit the shape of the given axis.
axis (int, optional) – The axis along which to repeat values. By default, use the flattened input array, and return a flat output array.
 Returns
repeated_array – Output array which has the same shape as a, except along the given axis.
 Return type
ndarray
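The following sketch shows total_repeat_length with arbitrary values. Eagerly, a concrete repeats array works on its own. Under jit, fixing total_repeat_length=4 fixes the output shape, and because sum(repeats) is 3 here, the final value pads the output. The function repeat_to_four is a hypothetical name used for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

twice = x.repeat(2)                    # [1, 1, 2, 2, 3, 3]
reps = jnp.array([1, 0, 2])
eager = x.repeat(reps)                 # [1, 3, 3]; needs concrete `reps`

@jax.jit
def repeat_to_four(v, r):
    # total_repeat_length fixes the output shape, so `r` may be traced.
    return v.repeat(r, total_repeat_length=4)

padded = repeat_to_four(x, reps)       # sum(reps) = 3 < 4, so the final value pads
```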
 round(decimals=0, out=None)
Evenly round to the given number of decimals.
LAX-backend implementation of numpy.around().
Original docstring below.
 Parameters
a (array_like) – Input data.
decimals (int, optional) – Number of decimal places to round to (default: 0). If decimals is negative, it specifies the number of positions to the left of the decimal point.
 Returns
rounded_array – An array of the same type as a, containing the rounded values. Unless out was specified, a new array is created. A reference to the result is returned.
The real and imaginary parts of complex numbers are rounded separately. The result of rounding a float is a float.
 Return type
ndarray
 Parameters
out (None) –
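The following sketch illustrates "evenly round", meaning that exact halves go to the even neighbor, using arbitrary values. With non-zero decimals, the rounding passes through float32 arithmetic, so those results are compared with a tolerance rather than exactly.

```python
import jax.numpy as jnp

halves = jnp.array([0.5, 1.5, 2.5, -0.5])
evens = halves.round()                     # ties go to even: [0., 2., 2., -0.]

two_places = jnp.array([1234.5678]).round(2)   # about 1234.57 in float32
hundreds = jnp.array([1234.0]).round(-2)       # negative decimals: about 1200.
```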
 searchsorted(v, side='left', sorter=None, *, method='scan')
Find indices where elements should be inserted to maintain order.
LAX-backend implementation of numpy.searchsorted().
Original docstring below.
Find the indices into a sorted array a such that, if the corresponding elements in v were inserted before the indices, the order of a would be preserved.
Assuming that a is sorted:
side     returned index i satisfies
left     a[i-1] < v <= a[i]
right    a[i-1] <= v < a[i]
 Parameters
a (1D array_like) – Input array. If sorter is None, then it must be sorted in ascending order, otherwise sorter must be an array of indices that sort it.
v (array_like) – Values to insert into a.
side ({'left', 'right'}, optional) – If ‘left’, the index of the first suitable location found is given. If ‘right’, return the last such index. If there is no suitable index, return either 0 or N (where N is the length of a).
method (str) – One of ‘scan’ (default) or ‘sort’. Controls the method used by the implementation; ‘scan’ tends to be more performant on CPU (particularly when a is very large), while ‘sort’ is often more performant on accelerator backends like GPU and TPU (particularly when v is very large).
sorter (None) –
 Returns
indices – Array of insertion points with the same shape as v, or an integer if v is a scalar.
 Return type
int or array of ints
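The following sketch applies the side table above to an arbitrary sorted array containing a repeated value. The last line switches to method='sort'. Per the text, this option changes only performance characteristics, so the result should match the default.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 2, 3, 5])    # must already be sorted

left = a.searchsorted(2)                     # 1: before the first 2
right = a.searchsorted(2, side='right')      # 3: after the last 2
many = a.searchsorted(jnp.array([0, 4, 6]))  # [0, 4, 5]

many_sort = a.searchsorted(jnp.array([0, 4, 6]), method='sort')  # same indices
```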
 sort(axis=-1, kind='quicksort', order=None)
Return a sorted copy of an array.
LAX-backend implementation of numpy.sort().
Original docstring below.
 Parameters
a (array_like) – Array to be sorted.
axis (int or None, optional) – Axis along which to sort. If None, the array is flattened before sorting. The default is -1, which sorts along the last axis.
kind ({'quicksort', 'mergesort', 'heapsort', 'stable'}, optional) –
Sorting algorithm. The default is ‘quicksort’. Note that both ‘stable’ and ‘mergesort’ use timsort or radix sort under the covers and, in general, the actual implementation will vary with data type. The ‘mergesort’ option is retained for backwards compatibility.
Changed in version 1.15.0.: The ‘stable’ option was added.
order (str or list of str, optional) – When a is an array with fields defined, this argument specifies which fields to compare first, second, etc. A single field can be specified as a string, and not all fields need be specified, but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties.
 Returns
sorted_array – Array of the same type and shape as a.
 Return type
ndarray
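NumPy's `ndarray.sort()` sorts in place and returns None. Because JAX arrays are immutable, this method instead returns a sorted copy. The sketch below, assuming a recent JAX release and an illustrative array `x`, shows the effect of the `axis` argument:

```python
import jax.numpy as jnp

x = jnp.array([[3, 7, 2],
               [9, 1, 8]])

by_row = x.sort()         # default axis=-1: sort each row -> [[2, 3, 7], [1, 8, 9]]
by_col = x.sort(axis=0)   # sort down each column        -> [[3, 1, 2], [9, 7, 8]]
flat = x.sort(axis=None)  # flatten first, then sort     -> [1, 2, 3, 7, 8, 9]

print(x)  # unchanged: the sorted values live only in the returned arrays
```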
 split(indices_or_sections, axis=0)
Split an array into multiple sub-arrays as views into ary.
LAX-backend implementation of numpy.split().
The JAX version of this function may in some cases return a copy rather than a view of the input.
Original docstring below.
 Parameters
ary (ndarray) – Array to be divided into subarrays.
indices_or_sections (int or 1-D array) –
If indices_or_sections is an integer, N, the array will be divided into N equal arrays along axis. If such a split is not possible, an error is raised.
If indices_or_sections is a 1-D array of sorted integers, the entries indicate where along axis the array is split. For example, [2, 3] would, for axis=0, result in
ary[:2]
ary[2:3]
ary[3:]
If an index exceeds the dimension of the array along axis, an empty sub-array is returned correspondingly.
axis (int, optional) – The axis along which to split, default is 0.
 Returns
subarrays – A list of sub-arrays as views into ary.
 Return type
list of ndarrays
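Both forms of indices_or_sections are shown below. This is a sketch, assuming a recent JAX release. It uses the function form `jnp.split`, and the input `x` is illustrative.

```python
import jax.numpy as jnp

x = jnp.arange(6)  # [0 1 2 3 4 5]

# An integer N divides the axis into N equal pieces (6 is divisible by 3).
parts = jnp.split(x, 3)        # pieces [0, 1], [2, 3], [4, 5]

# Sorted split points [2, 3] give x[:2], x[2:3], x[3:].
pieces = jnp.split(x, [2, 3])  # pieces [0, 1], [2], [3, 4, 5]

print(len(parts), len(pieces))  # 3 3
```

The result is a Python list of arrays, and in JAX each piece may be a copy rather than a view.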
 squeeze(axis=None)
Remove axes of length one from a.
LAX-backend implementation of numpy.squeeze().
The JAX version of this function may in some cases return a copy rather than a view of the input.
Original docstring below.
 Parameters
a (array_like) – Input data.
axis (None or int or tuple of ints, optional) –
 Returns
squeezed – The input array, but with all or a subset of the dimensions of length 1 removed. This is always a itself or a view into a. Note that if all axes are squeezed, the result is a 0-d array and not a scalar.
 Return type
ndarray
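The shapes produced with and without axis are shown below, including the 0-d case mentioned above. This is a minimal sketch, assuming a recent JAX release, with illustrative inputs:

```python
import jax.numpy as jnp

x = jnp.zeros((1, 3, 1))

all_squeezed = x.squeeze()      # drops every length-1 axis
first_only = x.squeeze(axis=0)  # drops only axis 0
zero_d = jnp.ones((1, 1)).squeeze()

print(all_squeezed.shape)  # (3,)
print(first_only.shape)    # (3, 1)
print(zero_d.shape)        # (): a 0-d array, not a Python scalar
```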
 std(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None)
Compute the standard deviation along the specified axis.
LAX-backend implementation of numpy.std().
Original docstring below.
Returns the standard deviation, a measure of the spread of a distribution, of the array elements. The standard deviation is computed for the flattened array by default, otherwise over the specified axis.
 Parameters
a (array_like) – Calculate the standard deviation of these values.
axis (None or int or tuple of ints, optional) – Axis or axes along which the standard deviation is computed. The default is to compute the standard deviation of the flattened array.
dtype (dtype, optional) – Type to use in computing the standard deviation. For arrays of integer type the default is float64, for arrays of float types it is the same as the array type.
ddof (int, optional) – Means Delta Degrees of Freedom. The divisor used in calculations is N - ddof, where N represents the number of elements. By default ddof is zero.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the std method of sub-classes of ndarray; however, any non-default value will be. If the sub-class’ method does not implement keepdims, any exceptions will be raised.
where (array_like of bool, optional) – Elements to include in the standard deviation. See numpy.ufunc.reduce for details.
out (None) –
 Returns
standard_deviation – If out is None, return a new array containing the standard deviation, otherwise return a reference to the output array.
 Return type
ndarray, see dtype parameter above.
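The N - ddof divisor and the where mask can be checked by hand on a small dataset. In this sketch, assuming a recent JAX release, the values are chosen so the squared deviations from the mean (5.0) sum to 32. The array names are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])

pop_std = x.std()               # divisor N = 8: sqrt(32 / 8) = 2.0
sample_std = x.std(ddof=1)      # divisor N - 1 = 7: sqrt(32 / 7), about 2.138
masked = x.std(where=x < 5.0)   # only [2, 4, 4, 4] contribute: sqrt(0.75)

# keepdims=True keeps the reduced axis (size 1), so the result broadcasts
# against the input, e.g. for per-row standardization.
m = jnp.array([[1.0, 3.0],
               [2.0, 6.0]])
row_std = m.std(axis=1, keepdims=True)              # shape (2, 1): [[1.0], [2.0]]
z = (m - m.mean(axis=1, keepdims=True)) / row_std   # [[-1, 1], [-1, 1]]
```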
 sum(axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)
Sum of array elements over a given axis.
LAX-backend implementation of numpy.sum().
Original docstring below.
 Parameters
a (array_like) – Elements to sum.
axis (None or int or tuple of ints, optional) – Axis or axes along which a sum is performed. The default, axis=None, will sum all of the elements of the input array. If axis is negative it counts from the last to the first axis.
dtype (dtype, optional) – The type of the returned array and of the accumulator in which the elements are summed. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the sum method of sub-classes of ndarray; however, any non-default value will be. If the sub-class’ method does not implement keepdims, any exceptions will be raised.
initial (scalar, optional) – Starting value for the sum. See numpy.ufunc.reduce for details.
where (array_like of bool, optional) – Elements to include in the sum. See numpy.ufunc.reduce for details.
promote_integers (bool, default=True) – If True, then integer inputs will be promoted to the widest available integer dtype, following NumPy’s behavior. If False, the result will have the same dtype as the input. promote_integers is ignored if dtype is specified.
out (None) –
 Returns
sum_along_axis – An array with the same shape as a, with the specified axis removed. If a is a 0-d array, or if axis is None, a scalar is returned. If an output array is specified, a reference to out is returned.
 Return type
ndarray
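The sketch below exercises axis, keepdims, initial, where and promote_integers on a small int8 array. It assumes a recent JAX release with 64-bit mode off by default. The array `x` is illustrative.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]], dtype=jnp.int8)

total = x.sum()                      # all elements: 21
cols = x.sum(axis=0)                 # [5, 7, 9]
rows = x.sum(axis=1, keepdims=True)  # shape (2, 1): [[6], [15]]
shifted = x.sum(initial=10)          # accumulation starts at 10: 31
evens = x.sum(where=(x % 2 == 0))    # only 2 + 4 + 6 = 12

# int8 input is widened by default; promote_integers=False keeps int8.
print(total.dtype)                          # int32 (int64 if jax_enable_x64 is set)
print(x.sum(promote_integers=False).dtype)  # int8
```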
 swapaxes(axis1, axis2)
Interchange two axes of an array.
LAX-backend implementation of numpy.swapaxes().
The JAX version of this function may in some cases return a copy rather than a view of the input.
Original docstring below.
 Parameters
a (array_like) – Input array.
axis1 (int) – First axis.
axis2 (int) – Second axis.
 Returns
a_swapped – For NumPy >= 1.10.0, if a is an ndarray, then a view of a is returned; otherwise a new array is created. For earlier NumPy versions a view of a is returned only if the order of the axes is changed, otherwise the input array is returned.
 Return type
ndarray
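Swapping two axes changes the shape and moves each element along with its indices. This is a minimal sketch, assuming a recent JAX release, with an illustrative 3-D array:

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)
y = x.swapaxes(0, 2)

print(y.shape)  # (4, 3, 2)
# Elements move with their axes: y[k, j, i] equals x[i, j, k].
print(y[3, 1, 0] == x[0, 1, 3])  # True
```

For this case the result equals `jnp.transpose(x, (2, 1, 0))`. Under `jax.jit`, XLA may avoid materializing the copy.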
 take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)
Take elements from an array along an axis.
LAX-backend implementation of numpy.take().
The JAX version adds several extra parameters, described below, which are forwarded to jax.lax.gather() for finer control over indexing.
Original docstring below.
When axis is not None, this function does the same thing as “fancy” indexing (indexing arrays using arrays); however, it can be easier to use if you need elements along a given axis. A call such as np.take(arr, indices, axis=3) is equivalent to arr[:,:,:,indices,...].
Explained without fancy indexing, this is equivalent to the following use of ndindex, which sets each of ii, jj, and kk to a tuple of indices:
    Ni, Nk = a.shape[:axis], a.shape[axis+1:]
    Nj = indices.shape
    for ii in ndindex(Ni):
        for jj in ndindex(Nj):
            for kk in ndindex(Nk):
                out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
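The fancy-indexing form and the ndindex loop above can be compared against each other. This is a sketch, assuming a recent JAX release. It runs the loop with `numpy.ndindex` on host copies of the arrays, and the shapes and index values are illustrative. A 2-D `indices` is used so the (Ni..., Nj..., Nk...) result shape is visible.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)  # Ni = (2,), M = 3, Nk = (4,)
indices = jnp.array([[0, 2], [1, 1]])       # Nj = (2, 2); every entry < M
axis = 1

out = a.take(indices, axis=axis)
print(out.shape)  # (2, 2, 2, 4), i.e. Ni + Nj + Nk

# Fancy-indexing form: the index array sits at position `axis`.
fancy = a[:, indices, :]

# Loop form from the docstring, on NumPy copies of the inputs.
a_np, idx_np = np.asarray(a), np.asarray(indices)
Ni, Nk = a_np.shape[:axis], a_np.shape[axis + 1:]
Nj = idx_np.shape
ref = np.empty(Ni + Nj + Nk, dtype=a_np.dtype)
for ii in np.ndindex(Ni):
    for jj in np.ndindex(Nj):
        for kk in np.ndindex(Nk):
            ref[ii + jj + kk] = a_np[ii + (idx_np[jj],) + kk]
```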
 Parameters
a (array_like (Ni..., M, Nk...)) – The source array.
indices (array_like (Nj...)) – The indices of the values to extract.
axis (int, optional) – The axis over which to select values. By default, the flattened input array is used.
mode (string, default="fill") – Out-of-bounds indexing mode. The default mode="fill" returns invalid values (e.g. NaN) for out-of-bounds indices. See jax.numpy.ndarray.at for more discussion of out-of-bounds indexing in JAX.
unique_indices (bool, default=False) – If True, the implementation will assume that the indices are unique, which can result in more efficient execution on some backends.
indices_are_sorted (bool, default=False) – If True, the implementation will assume that the indices are sorted in ascending order, which can lead to more efficient execution on some backends.
fill_value (optional) – The fill value to return for out-of-bounds slices when mode is ‘fill’. Ignored otherwise. Defaults to NaN for inexact types, the most negative (minimum) value for signed types, the largest positive value for unsigned types, and True for booleans.
 Returns
out – The returned array has the same type as a.
 Return type
ndarray (Ni…, Nj…, Nk…)
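Unlike NumPy, which raises an error, JAX does not raise on an out-of-bounds index. The mode and fill_value parameters decide what comes back instead. This is a sketch, assuming a recent JAX release. It assumes "clip" is accepted as an alternative mode alongside the documented "fill", and the values are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 2, 5])  # index 5 is out of bounds for length 3

filled = x.take(idx)                   # default "fill" mode: 10, 30, nan
custom = x.take(idx, fill_value=-1.0)  # "fill" with an explicit value: 10, 30, -1
clipped = x.take(idx, mode="clip")     # 5 is clamped to the last index: 10, 30, 30
```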
 trace(offset=0, axis1=0, axis2=1, dtype=None, out=None)
Return the sum along diagonals of the array.
LAX-backend implementation of numpy.trace().
Original docstring below.
If a is 2-D, the sum along its diagonal with the given offset is returned, i.e., the sum of elements a[i,i+offset] for all i.
If a has more than two dimensions, then the axes specified by axis1 and axis2 are used to determine the 2-D sub-arrays whose traces are returned. The shape of the resulting array is the same as that of a with axis1 and axis2 removed.
 Parameters
a (array_like) – Input array, from which the diagonals are taken.
offset (int, optional) – Offset of the diagonal from the main diagonal. Can be both positive and negative. Defaults to 0.
axis1 (int, optional) – Axes to be used as the first and second axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults are the first two axes of a.
axis2 (int, optional) – Axes to be used as the first and second axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults are the first two axes of a.
dtype (dtype, optional) – Determines the datatype of the returned array and of the accumulator where the elements are summed. If dtype has the value None and a is of integer type of precision less than the default integer precision, then the default integer precision is used. Otherwise, the precision is the same as that of a.
out (None) –
 Returns
sum_along_diagonals – If a is 2-D, the sum along the diagonal is returned. If a has larger dimensions, then an array of sums along diagonals is returned.
 Return type
ndarray
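The offset and axis1/axis2 behavior described above can be seen on small arrays. This is a minimal sketch, assuming a recent JAX release, with illustrative inputs:

```python
import jax.numpy as jnp

m = jnp.arange(9).reshape(3, 3)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]
main = m.trace()          # a[i, i]:     0 + 4 + 8 = 12
upper = m.trace(offset=1)  # a[i, i + 1]: 1 + 5 = 6
lower = m.trace(offset=-1) # a[i + 1, i]: 3 + 7 = 10

# For a 3-D array, axis1/axis2 pick the 2-D sub-arrays; axis 0 survives.
s = jnp.arange(2 * 3 * 3).reshape(2, 3, 3)
per_block = s.trace(axis1=1, axis2=2)  # [0 + 4 + 8, 9 + 13 + 17] = [12, 39]
```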
 unsafe_buffer_pointer()
(self: xla::PyBuffer::pyobject) -> StatusOr[int]
 var(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None)
Compute the variance along the specified axis.
LAX-backend implementation of numpy.var().
Original docstring below.
Returns the variance of the array elements, a measure of the spread of a distribution. The variance is computed for the flattened array by default, otherwise over the specified axis.
 Parameters
a (array_like) – Array containing numbers whose variance is desired. If a is not an array, a conversion is attempted.
axis (None or int or tuple of ints, optional) – Axis or axes along which the variance is computed. The default is to compute the variance of the flattened array.
dtype (data-type, optional) – Type to use in computing the variance. For arrays of integer type the default is float64; for arrays of float types it is the same as the array type.
ddof (int, optional) – “Delta Degrees of Freedom”: the divisor used in the calculation is N - ddof, where N represents the number of elements. By default ddof is zero.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the var method of sub-classes of ndarray; however, any non-default value will be. If the sub-class’ method does not implement keepdims, any exceptions will be raised.
where (array_like of bool, optional) – Elements to include in the variance. See numpy.ufunc.reduce for details.
out (None) –
 Returns
variance – If out=None, returns a new array containing the variance; otherwise, a reference to the output array is returned.
 Return type
ndarray, see dtype parameter above
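The N - ddof divisor and the relation to std can be checked on a small dataset whose squared deviations from the mean sum to 32. This is a minimal sketch, assuming a recent JAX release, with illustrative values:

```python
import jax.numpy as jnp

x = jnp.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])

pop_var = x.var()           # divisor N = 8: 32 / 8 = 4.0
sample_var = x.var(ddof=1)  # divisor N - 1 = 7: 32 / 7, about 4.571

# With the same ddof, the variance is the square of the standard deviation.
consistent = jnp.allclose(sample_var, x.std(ddof=1) ** 2)

m = jnp.array([[1.0, 2.0],
               [3.0, 6.0]])
col_var = m.var(axis=0)  # per-column variance: [1.0, 4.0]
```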
 xla_dynamic_shape()
(self: xla::PyBuffer::pyobject) -> StatusOr[jaxlib.xla_extension.Shape]
 xla_shape()
(self: xla::PyBuffer::pyobject) -> jaxlib.xla_extension.Shape