jax.numpy package
Implements the NumPy API, using the primitives in jax.lax.
While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.
Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, JAX is often able to provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides an alternative pure indexed update function, jax.ops.index_update().
Relatedly, some NumPy functions return views of arrays when possible (examples are numpy.transpose() and numpy.reshape()). JAX versions of such functions will return copies instead, although such copies can often be optimized away by XLA when sequences of operations are compiled using jax.jit().
NumPy is very aggressive at promoting values to float64 type. JAX is sometimes less aggressive about type promotion (see Type promotion semantics).
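The immutability and type-promotion differences described above can be observed directly. The following is a minimal sketch, not part of the original page. It uses the `.at[...].set(...)` property (documented under DeviceArray below) instead of jax.ops.index_update(), because index_update has been deprecated and later removed in recent JAX releases. The exact dtypes printed assume the default configuration with 64-bit mode (jax_enable_x64) disabled.

```python
import numpy as np
import jax.numpy as jnp

# NumPy mutates the array in place.
x_np = np.zeros(3)
x_np[1] = 5.0

# JAX arrays are immutable: item assignment raises a TypeError ...
x = jnp.zeros(3)
try:
    x[1] = 5.0
    assignment_allowed = True
except TypeError:
    assignment_allowed = False

# ... so use the functional `.at[...].set(...)`, which returns a new array
# and leaves `x` unchanged.
y = x.at[1].set(5.0)

# Type promotion: NumPy widens int32 with float16 to float64, while JAX's
# promotion lattice keeps the narrower float16.
np_result = np.promote_types(np.int32, np.float16)
jax_result = jnp.promote_types(jnp.int32, jnp.float16)

# Default float dtype: float32 unless 64-bit mode (jax_enable_x64) is enabled.
default_float = jnp.array([1.0]).dtype
print(np_result, jax_result, default_float)
```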
A small number of NumPy operations that have data-dependent output shapes are incompatible with jax.jit() compilation. The XLA compiler requires that shapes of arrays be known at compile time. While it would be possible to provide a JAX implementation of an API such as numpy.nonzero(), we would be unable to JIT-compile it because the shape of its output depends on the contents of the input data.
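The following sketch, which is not from the original page, shows the nonzero() problem in practice and one workaround. It assumes a recent JAX release in which jnp.nonzero accepts the keyword arguments `size` and `fill_value` (older releases may lack them). A static `size` makes the output shape independent of the data, so the call can be traced under jit.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 5, 0, 7])

# Outside jit this works: the number of non-zero entries (2) is known eagerly.
eager = jnp.nonzero(x)[0]

@jax.jit
def nonzero_unsized(v):
    # Under jit `v` is abstract, so the output length cannot be determined.
    return jnp.nonzero(v)

try:
    nonzero_unsized(x)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

@jax.jit
def nonzero_sized(v):
    # A static `size` fixes the output shape; unused slots get `fill_value`.
    return jnp.nonzero(v, size=3, fill_value=-1)[0]

padded = nonzero_sized(x)
```

The trade-off is that the caller must choose an upper bound on the number of results in advance and handle the padding entries.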
Not every function in NumPy is implemented; contributions are welcome!

Calculate the absolute value elementwise. 

Calculate the absolute value elementwise. 

Add arguments elementwise. 

Test whether all array elements along a given axis evaluate to True. 

Returns True if two arrays are elementwise equal within a tolerance. 

Test whether all array elements along a given axis evaluate to True. 

Return the maximum of an array or maximum along an axis. 

Return the minimum of an array or minimum along an axis. 

Return the angle of the complex argument. 

Test whether any array element along a given axis evaluates to True. 

Append values to the end of an array. 

Apply a function to 1D slices along the given axis. 

Apply a function repeatedly over multiple axes. 

Return evenly spaced values within a given interval. 

Trigonometric inverse cosine, elementwise. 

Inverse hyperbolic cosine, elementwise. 

Inverse sine, elementwise. 

Inverse hyperbolic sine elementwise. 

Trigonometric inverse tangent, elementwise. 

Elementwise arc tangent of 

Inverse hyperbolic tangent elementwise. 

Returns the indices of the maximum values along an axis. 

Returns the indices of the minimum values along an axis. 

Returns the indices that would sort an array. 

Find the indices of array elements that are nonzero, grouped by element. 

Evenly round to the given number of decimals. 

Create an array. 

True if two arrays have the same shape and elements, False otherwise. 

Returns True if input arrays are shape consistent and all elements equal. 

Return the string representation of an array. 

Split an array into multiple subarrays. 

Return a string representation of the data in an array. 

Convert the input to an array. 

Convert inputs to arrays with at least one dimension. 

View inputs as arrays with at least two dimensions. 

View inputs as arrays with at least three dimensions. 

Compute the weighted average along the specified axis. 

Return the Bartlett window. 

Count number of occurrences of each value in array of nonnegative ints. 

Compute the bitwise AND of two arrays elementwise. 

Compute bitwise inversion, or bitwise NOT, elementwise. 

Compute the bitwise OR of two arrays elementwise. 

Compute the bitwise XOR of two arrays elementwise. 

Return the Blackman window. 

Assemble an ndarray from nested lists of blocks. 



Like NumPy's broadcast_arrays but doesn't return views. 



Broadcast an array to a new shape. 
Concatenate slices, scalars and array-like objects along the last axis. 


Returns True if cast between data types can occur according to the casting rule. 

Return the cube-root of an array, elementwise. 
alias of 


Return the ceiling of the input, elementwise. 
Abstract base class of all character string scalar types. 


Construct an array from an index array and a set of arrays to choose from. 

Clip (limit) the values in an array. 

Stack 1D arrays as columns into a 2D array. 
alias of 





Abstract base class of all complex number scalar types that are made up of floating-point numbers. 

The warning raised when casting a complex dtype to a real dtype. 


Return selected slices of an array along given axis. 

Join a sequence of arrays along an existing axis. 

Return the complex conjugate, elementwise. 

Return the complex conjugate, elementwise. 

Returns the discrete, linear convolution of two one-dimensional sequences. 

Change the sign of x1 to that of x2, elementwise. 

Return Pearson product-moment correlation coefficients. 

Cross-correlation of two 1-dimensional sequences. 

Cosine elementwise. 

Hyperbolic cosine, elementwise. 

Counts the number of nonzero values in the array 

Estimate a covariance matrix, given data and weights. 

Return the cross product of two (arrays of) vectors. 
alias of 


Return the cumulative product of elements along a given axis. 

Return the cumulative product of elements along a given axis. 

Return the cumulative sum of the elements along a given axis. 

Convert angles from degrees to radians. 

Convert angles from radians to degrees. 

Return a new array with subarrays along an axis deleted. 

Extract a diagonal or construct a diagonal array. 

Create a two-dimensional array with the flattened input as a diagonal. 

Return the indices to access the main diagonal of an array. 

Return the indices to access the main diagonal of an n-dimensional array. 

Return specified diagonals. 

Calculate the nth discrete difference along the given axis. 

Return the indices of the bins to which each value in input array belongs. 

Returns a true division of the inputs, elementwise. 

Return elementwise quotient and remainder simultaneously. 

Dot product of two arrays. 
alias of 


Split array into multiple subarrays along the 3rd axis (depth). 

Stack arrays in sequence depth wise (along third axis). 

Create a data type object. 

The differences between consecutive elements of an array. 

Evaluates the Einstein summation convention on the operands. 

Evaluates the lowest cost contraction order for an einsum expression by 

Return a new array of given shape and type, filled with zeros. 

Return an array of zeros with the same shape and type as a given array. 

Return (x1 == x2) elementwise. 

Calculate the exponential of all elements in the input array. 

Calculate 2**p for all p in the input array. 

Expand the shape of an array. 

Calculate 

Return the elements of an array that satisfy some condition. 

Return a 2D array with ones on the diagonal and zeros elsewhere. 

Compute the absolute values elementwise. 

Machine limits for floating point types. 

Round to nearest integer towards zero. 

Return indices that are nonzero in the flattened version of a. 

Abstract base class of all scalar types without predefined length. 

Reverse the order of elements in an array along the given axis. 

Flip array in the left/right direction. 

Flip array in the up/down direction. 
alias of 








Abstract base class of all floating-point scalar types. 

First array elements raised to powers from second array, elementwise. 

Return the floor of the input, elementwise. 

Return the largest integer smaller or equal to the division of the inputs. 

Elementwise maximum of array elements. 

Elementwise minimum of array elements. 

Return the elementwise remainder of division. 

Decompose the elements of x into mantissa and twos exponent. 

Return a new array of given shape and type, filled with fill_value. 

Return a full array with the same shape and type as a given array. 

Returns the greatest common divisor of 

Return numbers spaced evenly on a log scale (a geometric progression). 

Return the gradient of an N-dimensional array. 

Return the truth value of (x1 > x2) elementwise. 

Return the truth value of (x1 >= x2) elementwise. 

Return the Hamming window. 

Return the Hanning window. 

Compute the Heaviside step function. 

Compute the histogram of a set of data. 

Function to calculate only the edges of the bins used by the histogram 

Compute the bidimensional histogram of two data samples. 

Compute the multidimensional histogram of some data. 

Split an array into multiple subarrays horizontally (column-wise). 

Stack arrays in sequence horizontally (column wise). 

Given the 'legs' of a right triangle, return its hypotenuse. 

Modified Bessel function of the first kind, order 0. 

Return the identity array. 

Machine limits for integer types. 

Return the imaginary part of the complex argument. 

Test whether each element of a 1D array is also present in a second array. 

Return an array representing the indices of a grid. 

Abstract base class of all numeric scalar types with a (potentially) inexact representation of the values in its range, such as floating-point numbers. 

Inner product of two arrays. 
alias of 










Abstract base class of all integer scalar types. 

One-dimensional linear interpolation. 

Find the intersection of two arrays. 

Compute bitwise inversion, or bitwise NOT, elementwise. 

Returns a boolean array where two arrays are elementwise equal within a 

Returns a bool array, where True if input element is complex. 

Check for a complex type or an array of complex numbers. 

Test elementwise for finiteness (not infinity and not Not a Number). 

Calculates element in test_elements, broadcasting over element only. 

Test elementwise for positive or negative infinity. 

Test elementwise for NaN and return result as a boolean array. 

Test elementwise for negative infinity, return result as bool array. 

Test elementwise for positive infinity, return result as bool array. 

Returns a bool array, where True if input element is real. 

Return True if x is not a complex type or an array of complex numbers. 

Returns True if the type of element is a scalar type. 

Returns True if first argument is a typecode lower/equal in type hierarchy. 

Determine if the first argument is a subclass of the second argument. 

Check whether or not an object can be iterated over. 

Construct an open mesh from multiple sequences. 

Return the Kaiser window. 

Kronecker product of two arrays. 

Returns the lowest common multiple of 

Returns x1 * 2**x2, elementwise. 

Shift the bits of an integer to the left. 

Return the truth value of (x1 < x2) elementwise. 

Return the truth value of (x1 <= x2) elementwise. 

Perform an indirect stable sort using a sequence of keys. 

Return evenly spaced numbers over a specified interval. 

Load arrays or pickled objects from 

Natural logarithm, elementwise. 

Return the base 10 logarithm of the input array, elementwise. 

Return the natural logarithm of one plus the input array, elementwise. 

Base-2 logarithm of x.
Logarithm of the sum of exponentiations of the inputs. 

Logarithm of the sum of exponentiations of the inputs in base-2. 


Compute the truth value of x1 AND x2 elementwise. 

Compute the truth value of NOT x elementwise. 

Compute the truth value of x1 OR x2 elementwise. 

Compute the truth value of x1 XOR x2, elementwise. 

Return numbers spaced evenly on a log scale. 

Return the indices to access (n, n) arrays, given a masking function. 

Matrix product of two arrays. 

Return the maximum of an array or maximum along an axis. 

Elementwise maximum of array elements. 

Compute the arithmetic mean along the specified axis. 

Compute the median along the specified axis. 

Return coordinate matrices from coordinate vectors. 
Return dense multidimensional 'meshgrid'. 


Return the minimum of an array or minimum along an axis. 

Elementwise minimum of array elements. 

Return elementwise remainder of division. 

Return the fractional and integral parts of an array, elementwise. 

Move axes of an array to new positions. 

Return a copy of an array sorted along the first axis. 

Multiply arguments elementwise. 

Return the indices of the maximum values in the specified axis ignoring 

Return the indices of the minimum values in the specified axis ignoring 

Return the cumulative product of array elements over a given axis treating Not a 

Return the cumulative sum of array elements over a given axis treating Not a 

Return the maximum of an array or maximum along an axis, ignoring any 

Compute the arithmetic mean along the specified axis, ignoring NaNs. 

Compute the median along the specified axis, while ignoring NaNs. 

Return minimum of an array or minimum along an axis, ignoring any NaNs. 

Compute the q-th percentile of the data along the specified axis, 

Return the product of array elements over a given axis treating Not a 

Compute the q-th quantile of the data along the specified axis, 

Compute the standard deviation along the specified axis, while 

Return the sum of array elements over a given axis treating Not a 

Replace NaN with zero and infinity with large finite numbers (default 

Compute the variance along the specified axis, while ignoring NaNs. 



Return the number of dimensions of an array. 

Numerical negative, elementwise. 

Return the next floating-point value after x1 towards x2, elementwise. 

Return the indices of the elements that are nonzero. 

Return (x1 != x2) elementwise. 

Abstract base class of all numeric scalar types. 
Any Python object. 

Return open multidimensional 'meshgrid'. 


Return a new array of given shape and type, filled with ones. 

Return an array of ones with the same shape and type as a given array. 

Compute the outer product of two vectors. 

Packs the elements of a binary-valued array into bits in a uint8 array. 

Pad an array. 

Compute the q-th percentile of the data along the specified axis. 

Evaluate a piecewise-defined function. 

Find the coefficients of a polynomial with the given sequence of roots. 

Find the sum of two polynomials. 

Return the derivative of the specified order of a polynomial. 

Return an antiderivative (indefinite integral) of a polynomial. 

Find the product of two polynomials. 

Difference (subtraction) of two polynomials. 

Evaluate a polynomial at specific values. 

Numerical positive, elementwise. 

First array elements raised to powers from second array, elementwise. 

Return the product of array elements over a given axis. 

Return the product of array elements over a given axis. 

Returns the type to which a binary operation should cast its arguments. 

Range of values (maximum - minimum) along an axis. 

Compute the q-th quantile of the data along the specified axis.
Concatenate slices, scalars and array-like objects along the first axis. 


Convert angles from radians to degrees. 

Convert angles from degrees to radians. 

Return a contiguous flattened array. 

Converts a tuple of index arrays into an array of flat 

Return the real part of the complex argument. 

Return the reciprocal of the argument, elementwise. 

Return elementwise remainder of division. 

Repeat elements of an array. 

Gives a new shape to an array without changing its data. 

Return a new array with the specified shape. 

Returns the type that results from applying the NumPy 

Shift the bits of an integer to the right. 

Round elements of the array to the nearest integer. 

Roll array elements along a given axis. 

Roll the specified axis backwards, until it lies in a given position. 

Return the roots of a polynomial with coefficients given in p. 

Rotate an array by 90 degrees in the plane specified by axes. 

Evenly round to the given number of decimals. 

Evenly round to the given number of decimals. 

Stack arrays in sequence vertically (row wise). 

Save an array to a binary file in NumPy 

Save several arrays into a single file in uncompressed 

Find indices where elements should be inserted to maintain order. 

Return an array drawn from elements in choicelist, depending on conditions. 

Set printing options. 

Find the set difference of two arrays. 

Find the set exclusive-or of two arrays. 

Return the shape of an array. 

Returns an elementwise indication of the sign of a number. 

Returns elementwise True where signbit is set (less than zero). 
Abstract base class of all signed integer scalar types. 


Trigonometric sine, elementwise. 

Return the sinc function. 
alias of 


Hyperbolic sine, elementwise. 

Return the number of elements along a given axis. 

Test whether any array element along a given axis evaluates to True. 

Return a sorted copy of an array. 

Sort a complex array using the real part first, then the imaginary part. 

Split an array into multiple subarrays as views into ary. 

Return the non-negative square-root of an array, elementwise. 

Return the elementwise square of the input. 

Remove single-dimensional entries from the shape of an array. 

Join a sequence of arrays along a new axis. 

Compute the standard deviation along the specified axis. 

Subtract arguments, elementwise. 

Sum of array elements over a given axis. 

Interchange two axes of an array. 

Take elements from an array along an axis. 

Take values from the input array by matching 1d index and data slices. 

Compute tangent elementwise. 

Compute hyperbolic tangent elementwise. 

Compute tensor dot product along specified axes. 

Construct an array by repeating A the number of times given by reps. 

Return the sum along diagonals of the array. 

Reverse or permute the axes of an array; returns the modified array. 

Integrate along the given axis using the composite trapezoidal rule. 

An array with ones at and below the given diagonal and zeros elsewhere. 

Lower triangle of an array. 

Return the indices for the lower-triangle of an (n, m) array. 

Return the indices for the lower-triangle of arr. 

Trim the leading and/or trailing zeros from a 1D array or sequence. 

Upper triangle of an array. 

Return the indices for the upper-triangle of an (n, m) array. 

Return the indices for the upper-triangle of arr. 

Returns a true division of the inputs, elementwise. 

Return the truncated value of the input, elementwise. 









Find the unique elements of an array. 

Find the union of two arrays. 

Unpacks elements of a uint8 array into a binary-valued output array. 

Converts a flat index or array of flat indices into a tuple 
Abstract base class of all unsigned integer scalar types. 


Unwrap by changing deltas between values to 2*pi complement. 

Generate a Vandermonde matrix. 

Compute the variance along the specified axis. 

Return the dot product of two vectors. 

Define a vectorized function with broadcasting. 

Split an array into multiple subarrays vertically (row-wise). 

Stack arrays in sequence vertically (row wise). 

Return elements chosen from x or y depending on condition. 

Return a new array of given shape and type, filled with zeros. 

Return an array of zeros with the same shape and type as a given array. 
jax.numpy.fft

Compute the one-dimensional discrete Fourier Transform. 

Compute the 2-dimensional discrete Fourier Transform 

Return the Discrete Fourier Transform sample frequencies. 

Compute the N-dimensional discrete Fourier Transform. 

Shift the zero-frequency component to the center of the spectrum. 

Compute the FFT of a signal that has Hermitian symmetry, i.e., a real 

Compute the one-dimensional inverse discrete Fourier Transform. 

Compute the 2-dimensional inverse discrete Fourier Transform. 

Compute the N-dimensional inverse discrete Fourier Transform. 

The inverse of fftshift. 

Compute the inverse FFT of a signal that has Hermitian symmetry. 

Compute the inverse of the n-point DFT for real input. 

Compute the 2-dimensional inverse FFT of a real array. 

Compute the inverse of the N-dimensional FFT of real input. 

Compute the one-dimensional discrete Fourier Transform for real input. 

Compute the 2-dimensional FFT of a real array. 

Return the Discrete Fourier Transform sample frequencies 

Compute the N-dimensional discrete Fourier Transform for real input.
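As a small illustration of how the real-input transforms listed above fit together, here is a sketch (not from the original page) that round-trips a short real signal through jnp.fft.rfft and jnp.fft.irfft. It also checks the result against the full complex FFT and computes the matching rfftfreq bins. The signal values and the sample spacing d = 1.0 are arbitrary choices for the example.

```python
import jax.numpy as jnp

signal = jnp.array([1.0, 2.0, 0.0, -1.0, 3.0, 0.5])
n = signal.shape[0]

# For real input, rfft keeps only the n // 2 + 1 non-negative-frequency terms.
spectrum = jnp.fft.rfft(signal)

# They match the first half of the full complex FFT.
full = jnp.fft.fft(signal)

# Passing the original length n lets irfft invert rfft exactly.
recovered = jnp.fft.irfft(spectrum, n=n)

# Sample frequencies matching `spectrum`, for sample spacing d = 1.0.
freqs = jnp.fft.rfftfreq(n, d=1.0)
```

Passing `n` explicitly matters for odd-length or ambiguous cases, because irfft otherwise assumes an even output length.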
jax.numpy.linalg

Cholesky decomposition. 

Compute the condition number of a matrix. 
Compute the determinant of an array. 


Compute the eigenvalues and right eigenvectors of a square array. 

Return the eigenvalues and eigenvectors of a complex Hermitian 

Compute the eigenvalues of a general matrix. 

Compute the eigenvalues of a complex Hermitian or real symmetric matrix. 

Compute the (multiplicative) inverse of a matrix. 

Return the least-squares solution to a linear matrix equation. 

Raise a square matrix to the (integer) power n. 

Return matrix rank of array using SVD method 

Compute the dot product of two or more arrays in a single function call, 

Matrix or vector norm. 
Compute the (Moore-Penrose) pseudo-inverse of a matrix. 


Compute the qr factorization of a matrix. 
Compute the sign and (natural) logarithm of the determinant of an array. 


Solve a linear matrix equation, or system of linear scalar equations. 

Singular Value Decomposition. 

Compute the 'inverse' of an N-dimensional array. 

Solve the tensor equation 
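To show a few of the linalg entries above working together, here is a sketch (not from the original page) that solves a small 2x2 system with jnp.linalg.solve. It compares the result with multiplication by the explicit inverse and checks the determinant through both det and slogdet. The matrix is chosen so that the exact solution is (2, 3) and the determinant is 3*2 - 1*1 = 5.

```python
import jax.numpy as jnp

# Solve the 2x2 system  3*x0 + x1 = 9,  x0 + 2*x1 = 8.
a = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([9.0, 8.0])

x = jnp.linalg.solve(a, b)

# Multiplying by the explicit inverse gives the same answer, but solve()
# is usually preferred for accuracy and cost.
x_via_inv = jnp.linalg.inv(a) @ b

# slogdet returns the sign and log|det|; det(a) = 3*2 - 1*1 = 5.
sign, logabsdet = jnp.linalg.slogdet(a)
det = jnp.linalg.det(a)
```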
JAX DeviceArray
The JAX DeviceArray is the core array object in JAX: you can think of it as the equivalent of a numpy.ndarray backed by a memory buffer on a single device. Like numpy.ndarray, most users will not need to instantiate DeviceArray objects manually, but rather will create them via jax.numpy functions like array(), arange(), linspace(), and others listed above.
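The following sketch (not from the original page) creates arrays through the jax.numpy constructors named above. One assumption applies: it targets recent JAX releases, where the DeviceArray type described here has been superseded by the unified jax.Array type. The isinstance check therefore uses jax.Array, and older releases would report a DeviceArray class instead.

```python
import jax
import jax.numpy as jnp

a = jnp.array([1, 2, 3])
b = jnp.arange(3)
c = jnp.linspace(0.0, 1.0, 5)

# In recent JAX releases all three are instances of jax.Array.
print(type(a).__name__, isinstance(c, jax.Array))
```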

jax.numpy.DeviceArray
alias of jaxlib.xla_extension.DeviceArrayBase

class jaxlib.xla_extension.DeviceArrayBase

class jaxlib.xla_extension.DeviceArray
property T
Reverse or permute the axes of an array; returns the modified array.
LAX-backend implementation of transpose().
The JAX version of this function will return a copy rather than a view of the input.
Original docstring below.
For an array a with two axes, transpose(a) gives the matrix transpose.
Parameters
a (array_like) - Input array.
axes (tuple or list of ints, optional) - If specified, it must be a tuple or list which contains a permutation of [0, 1, ..., N-1] where N is the number of axes of a. The i-th axis of the returned array will correspond to the axis numbered axes[i] of the input. If not specified, defaults to range(a.ndim)[::-1], which reverses the order of the axes.
Returns
p - a with its axes permuted. A view is returned whenever possible.
Return type

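Here is a short sketch (not from the original page) of the axis permutation rule above, using an arbitrary 3-D example. With no `axes` argument, the axes are reversed. With an explicit permutation, output axis i is input axis axes[i], so for axes=(1, 0, 2) the element y[i, j, k] equals x[j, i, k]. As noted above, JAX returns a new array rather than a view.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

# .T reverses all axes, like transpose(x) with no `axes` argument.
reversed_axes = x.T                           # shape (4, 3, 2)

# Explicit permutation: output axis i is input axis axes[i].
y = jnp.transpose(x, axes=(1, 0, 2))          # shape (3, 2, 4)
```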
all(axis=None, out=None, keepdims=None, *, where=None)
Test whether all array elements along a given axis evaluate to True.
LAX-backend implementation of all().
Original docstring below.
Parameters
a (array_like) - Input array or object that can be converted to an array.
axis (None or int or tuple of ints, optional) - Axis or axes along which a logical AND reduction is performed. The default (axis=None) is to perform a logical AND over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.
keepdims (bool, optional) - If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the all method of subclasses of ndarray, however any non-default value will be. If the subclass' method does not implement keepdims any exceptions will be raised.
Returns
all - A new boolean or array is returned unless out is specified, in which case a reference to out is returned.
Return type

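The `axis` and `keepdims` parameters described above can be seen on a small boolean matrix. This sketch is not from the original page and uses arbitrary data. The default reduces over every axis, `axis=0` performs a logical AND down each column, and `keepdims=True` leaves the reduced axis in place with size one.

```python
import jax.numpy as jnp

m = jnp.array([[True, True, False],
               [True, True, True]])

everything = m.all()                     # reduce over every axis
per_column = m.all(axis=0)               # logical AND down each column
per_row = m.all(axis=1, keepdims=True)   # reduced axis kept with size 1
```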
any(axis=None, out=None, keepdims=None, *, where=None)
Test whether any array element along a given axis evaluates to True.
LAX-backend implementation of any().
Original docstring below.
Returns single boolean unless axis is not None.
Parameters
a (array_like) - Input array or object that can be converted to an array.
axis (None or int or tuple of ints, optional) - Axis or axes along which a logical OR reduction is performed. The default (axis=None) is to perform a logical OR over all the dimensions of the input array. axis may be negative, in which case it counts from the last to the first axis.
keepdims (bool, optional) - If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the any method of subclasses of ndarray, however any non-default value will be. If the subclass' method does not implement keepdims any exceptions will be raised.
Returns
any - A new boolean or ndarray is returned unless out is specified, in which case a reference to out is returned.
Return type

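Here is a sketch (not from the original page) of any() with a tuple of axes, one of the `axis` forms listed above. Non-zero numbers count as True. The batch layout, two 3x4 arrays with a single non-zero entry in the second, is an arbitrary example.

```python
import jax.numpy as jnp

# A batch of two 3x4 arrays; only the second contains a non-zero entry.
x = jnp.zeros((2, 3, 4)).at[1, 2, 0].set(1.0)

has_any = jnp.any(x)                  # non-zero numbers count as True
per_batch = jnp.any(x, axis=(1, 2))   # reduce the last two axes together
```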
argmax(axis=None, out=None)
Returns the indices of the maximum values along an axis.
LAX-backend implementation of argmax().
Original docstring below.
Parameters
a (array_like) - Input array.
axis (int, optional) - By default, the index is into the flattened array, otherwise along the specified axis.
Returns
index_array - Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.
Return type
ndarray of ints
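The two indexing modes described above (flattened by default, or per axis) are shown in this sketch, which is not from the original page. The scores are arbitrary. jnp.unravel_index is used to turn the flat index back into (row, column) coordinates.

```python
import jax.numpy as jnp

scores = jnp.array([[0.1, 0.7, 0.2],
                    [0.5, 0.3, 0.9]])

flat_idx = jnp.argmax(scores)          # index into the flattened array
row_idx = jnp.argmax(scores, axis=1)   # one index per row; axis 1 removed

# Convert the flat index back to (row, col) coordinates.
coords = jnp.unravel_index(flat_idx, scores.shape)
```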

argmin(axis=None, out=None)
Returns the indices of the minimum values along an axis.
LAX-backend implementation of argmin().
Original docstring below.
Parameters
a (array_like) - Input array.
axis (int, optional) - By default, the index is into the flattened array, otherwise along the specified axis.
Returns
index_array - Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.
Return type
ndarray of ints
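This sketch (not from the original page) focuses on two details of argmin(): how ties are handled and the shape of the result. It assumes JAX follows NumPy in returning the first occurrence when the minimum is repeated. It also shows that reducing along axis 0 removes that dimension from the output shape.

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 4, 1, 5])
first_min = jnp.argmin(x)   # the minimum 1 occurs at 1 and 3; the first wins

grid = jnp.array([[2, 0],
                  [1, 3],
                  [0, 0]])
col_argmin = jnp.argmin(grid, axis=0)   # shape (2,): axis 0 is removed
```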

argpartition(**kwargs)
Perform an indirect partition along the given axis using the algorithm specified by the kind keyword.
LAX-backend implementation of argpartition().
This function is not yet implemented by jax.numpy, and will raise NotImplementedError.
Original docstring below.
It returns an array of indices of the same shape as a that index data along the given axis in partitioned order.
New in version 1.8.0.
Parameters
a (array_like) - Array to sort.
kth (int or sequence of ints) - Element index to partition by. The k-th element will be in its final sorted position and all smaller elements will be moved before it and all larger elements behind it. The order of all elements in the partitions is undefined. If provided with a sequence of k-th it will partition all of them into their sorted position at once.
axis (int or None, optional) - Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used.
kind ({'introselect'}, optional) - Selection algorithm. Default is 'introselect'.
order (str or list of str, optional) - When a is an array with fields defined, this argument specifies which fields to compare first, second, etc. A single field can be specified as a string, and not all fields need be specified, but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties.
Returns
index_array - Array of indices that partition a along the specified axis. If a is one-dimensional, a[index_array] yields a partitioned a. More generally, np.take_along_axis(a, index_array, axis=axis) always yields the partitioned a, irrespective of dimensionality.
Return type

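Because argpartition is documented above as unimplemented in this version of jax.numpy, one workaround is to slice the output of argsort. This assumes the cost of a full sort is acceptable. The helper `k_smallest_indices` below is hypothetical and not part of JAX. Unlike a true partition, it also orders the k selected elements. Recent JAX releases may implement argpartition directly.

```python
import jax.numpy as jnp

def k_smallest_indices(a, k):
    """Hypothetical helper: indices of the k smallest values along the last axis.

    Uses a full sort (O(n log n)), so the k selected indices also come out
    in ascending order of value, which argpartition does not guarantee.
    """
    return jnp.argsort(a, axis=-1)[..., :k]

x = jnp.array([7.0, 2.0, 9.0, 1.0, 5.0])
idx = k_smallest_indices(x, 2)
smallest = x[idx]
```

Applying jax.lax.top_k to the negated values is another way to select the k smallest entries.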
argsort(axis=-1, kind='quicksort', order=None)
Returns the indices that would sort an array.
LAX-backend implementation of argsort().
Original docstring below.
Perform an indirect sort along the given axis using the algorithm specified by the kind keyword. It returns an array of indices of the same shape as a that index data along the given axis in sorted order.
Parameters
a (array_like) - Array to sort.
axis (int or None, optional) - Axis along which to sort. The default is -1 (the last axis). If None, the flattened array is used.
kind ({'quicksort', 'mergesort', 'heapsort', 'stable'}, optional) - Sorting algorithm. The default is 'quicksort'. Note that both 'stable' and 'mergesort' use timsort under the covers and, in general, the actual implementation will vary with data type. The 'mergesort' option is retained for backwards compatibility.
Changed in version 1.15.0: The 'stable' option was added.
order (str or list of str, optional) - When a is an array with fields defined, this argument specifies which fields to compare first, second, etc. A single field can be specified as a string, and not all fields need be specified, but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties.
Returns
index_array - Array of indices that sort a along the specified axis. If a is one-dimensional, a[index_array] yields a sorted a. More generally, np.take_along_axis(a, index_array, axis=axis) always yields the sorted a, irrespective of dimensionality.
Return type

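The take_along_axis identity quoted in the docstring above can be checked on a 2-D array. This is a minimal sketch, not from the original page, with arbitrary data. Sorting each row via argsort and jnp.take_along_axis gives the same result as jnp.sort along that axis. The sketch does not pass `kind`, because its handling differs between JAX releases.

```python
import jax.numpy as jnp

a = jnp.array([[3, 1, 2],
               [9, 7, 8]])

order = jnp.argsort(a, axis=-1)                    # per-row sorting indices
sorted_a = jnp.take_along_axis(a, order, axis=-1)  # gather in sorted order
```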
property at
Indexable helper object to call indexed update functions.
The at property is syntactic sugar for calling the indexed update functions defined in jax.ops, and acts as a pure equivalent of in-place modifications. For further information, see Indexed Update Operators. In particular:
x = x.at[idx].set(y) is a pure equivalent of x[idx] = y.
x = x.at[idx].add(y) is a pure equivalent of x[idx] += y.
x = x.at[idx].multiply(y) (aka mul) is a pure equivalent of x[idx] *= y.
x = x.at[idx].divide(y) is a pure equivalent of x[idx] /= y.
x = x.at[idx].power(y) is a pure equivalent of x[idx] **= y.
x = x.at[idx].min(y) is a pure equivalent of x[idx] = minimum(x[idx], y).
x = x.at[idx].max(y) is a pure equivalent of x[idx] = maximum(x[idx], y).
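Here is a sketch (not from the original page) of a few of the update methods listed above. It makes three points. First, every call returns a new array and leaves `x` untouched. Second, with add() and max(), repeated indices are combined, so the two updates aimed at index 0 both take effect. Third, updates can be chained. For set() with repeated indices, which update wins is not guaranteed, so the sketch avoids that case.

```python
import jax.numpy as jnp

x = jnp.zeros(5)
idx = jnp.array([0, 0, 2])

# Duplicate indices accumulate: index 0 receives two increments.
added = x.at[idx].add(1.0)

# max() combines all updates per index: max(0, 3, -1) = 3 and max(0, 4) = 4.
maxed = x.at[idx].max(jnp.array([3.0, -1.0, 4.0]))

# Updates can be chained; each step returns a new array.
chained = x.at[1].set(10.0).at[1].multiply(0.5)
```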

block_host_until_ready()
(self: xla::PyBuffer::pyobject) -> Status

block_until_ready()
(self: xla::PyBuffer::pyobject) -> StatusOr[xla::PyBuffer::pyobject]

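A common use of block_until_ready() is timing. This sketch is not from the original page. JAX dispatches work asynchronously, so a call such as jnp.dot can return before the result has actually been computed. Waiting on the result ensures the timer measures the computation itself. The first call also includes one-off dispatch or compilation overhead, so this is a rough illustration rather than a benchmark.

```python
import time
import jax.numpy as jnp

x = jnp.ones((512, 512))

start = time.perf_counter()
# Without block_until_ready() the timer could stop before the matmul finishes;
# the method waits for the result and returns the same array.
y = jnp.dot(x, x).block_until_ready()
elapsed = time.perf_counter() - start
print(f"matmul took {elapsed:.4f} s")
```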
broadcast(sizes)
Broadcasts an array, adding new major dimensions.
Wraps XLA's Broadcast operator.

broadcast_in_dim(shape, broadcast_dimensions)
Wraps XLA's BroadcastInDim operator.
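The two broadcast methods above wrap XLA operators that are also exposed as jax.lax.broadcast and jax.lax.broadcast_in_dim. This sketch (not from the original page) calls the jax.lax functions, on the assumption that the array methods may be missing in recent JAX releases. broadcast prepends new leading (major) dimensions. broadcast_in_dim states which output dimension each operand dimension maps to, so here the values are copied along a new trailing axis instead.

```python
import jax.numpy as jnp
from jax import lax

v = jnp.array([1, 2, 3])

# broadcast: prepend a new major dimension of size 2; each row is v.
b = lax.broadcast(v, (2,))

# broadcast_in_dim: operand dim 0 maps to output dim 0, so each row i
# is filled with v[i] along the new trailing axis.
c = lax.broadcast_in_dim(v, shape=(3, 2), broadcast_dimensions=(0,))
```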

clip(a_min=None, a_max=None, out=None)
Clip (limit) the values in an array.
LAX-backend implementation of clip().
Original docstring below.
Given an interval, values outside the interval are clipped to the interval edges. For example, if an interval of [0, 1] is specified, values smaller than 0 become 0, and values larger than 1 become 1.
Equivalent to but faster than np.minimum(a_max, np.maximum(a, a_min)).
No check is performed to ensure a_min < a_max.
Parameters
a (array_like) - Array containing elements to clip.
a_min (scalar or array_like or None) - Minimum value. If None, clipping is not performed on lower interval edge. Not more than one of a_min and a_max may be None.
a_max (scalar or array_like or None) - Maximum value. If None, clipping is not performed on upper interval edge. Not more than one of a_min and a_max may be None. If a_min or a_max are array_like, then the three arrays will be broadcasted to match their shapes.
Returns
clipped_array - An array with the elements of a, but where values < a_min are replaced with a_min, and those > a_max with a_max.
Return type

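The equivalence and the broadcasting behaviour described for clip() can be checked in this sketch, which is not from the original page. The bound keyword names have changed across JAX releases (a_min/a_max here, and possibly min/max in newer versions), so the sketch passes the bounds positionally.

```python
import jax.numpy as jnp

x = jnp.array([-1.5, 0.2, 0.8, 3.0])

# Positional bounds: lower bound 0.0, upper bound 1.0.
clipped = jnp.clip(x, 0.0, 1.0)

# The equivalent formulation quoted in the docstring.
manual = jnp.minimum(1.0, jnp.maximum(x, 0.0))

# Array-valued bounds broadcast against x.
lower = jnp.array([0.0, 0.5, 0.0, 0.0])
per_element = jnp.clip(x, lower, 2.0)
```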
clone()
(self: xla::PyBuffer::pyobject) -> xla::PyBuffer::pyobject

conj()
Return the complex conjugate, elementwise.
LAX-backend implementation of conjugate().
Original docstring below.
The complex conjugate of a complex number is obtained by changing the sign of its imaginary part.
Parameters
x (array_like) - Input value.
Returns
y - The complex conjugate of x, with same dtype as x. This is a scalar if x is a scalar.
Return type

conjugate()
Return the complex conjugate, elementwise.
LAX-backend implementation of conjugate().
Original docstring below.
The complex conjugate of a complex number is obtained by changing the sign of its imaginary part.
Parameters
x (array_like) - Input value.
Returns
y - The complex conjugate of x, with same dtype as x. This is a scalar if x is a scalar.
Return type

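Here is a sketch (not from the original page) showing that conj() and conjugate() agree and preserve the complex dtype. It also shows that multiplying by the conjugate yields the squared magnitude |z|**2, since the imaginary parts cancel. The values are arbitrary.

```python
import jax.numpy as jnp

z = jnp.array([1 + 2j, -3 - 4j], dtype=jnp.complex64)

zc = z.conj()
same = bool(jnp.array_equal(zc, z.conjugate()))  # the two methods agree

# z * conj(z) equals |z|**2: (1+2j)(1-2j) = 5 and (-3-4j)(-3+4j) = 25.
mag2 = (z * zc).real
```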
copy()
Returns an ndarray (backed by host memory, not device memory).

copy_to_device()
(self: xla::PyBuffer::pyobject, arg0: jaxlib.xla_extension.Device) -> StatusOr[object]

copy_to_host_async()
(self: xla::PyBuffer::pyobject) -> Status

cumprod(axis=None, dtype=None, out=None)
Return the cumulative product of elements along a given axis.
LAX-backend implementation of cumprod().
Original docstring below.
Parameters
a (array_like) – Input array.
axis (int, optional) – Axis along which the cumulative product is computed. By default the input is flattened.
dtype (dtype, optional) – Type of the returned array, as well as of the accumulator in which the elements are multiplied. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used instead.
Returns
cumprod – A new array holding the result is returned unless out is specified, in which case a reference to out is returned.
Return type
ndarray
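This sketch shows how the axis argument changes what cumprod accumulates over, using a small integer array. The expected values in the comments were worked out by hand.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [4, 5, 6]])

flat = a.cumprod()          # axis=None: flattened first -> [1, 2, 6, 24, 120, 720]
down = a.cumprod(axis=0)    # running product down each column -> [[1, 2, 3], [4, 10, 18]]
across = a.cumprod(axis=1)  # running product along each row -> [[1, 2, 6], [4, 20, 120]]
```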

cumsum(axis=None, dtype=None, out=None)
Return the cumulative sum of the elements along a given axis.
LAX-backend implementation of cumsum().
Original docstring below.
Parameters
a (array_like) – Input array.
axis (int, optional) – Axis along which the cumulative sum is computed. The default (None) is to compute the cumsum over the flattened array.
dtype (dtype, optional) – Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of a, unless a has an integer dtype with a precision less than that of the default platform integer. In that case, the default platform integer is used.
Returns
cumsum_along_axis – A new array holding the result is returned unless out is specified, in which case a reference to out is returned. The result has the same size as a, and the same shape as a if axis is not None or a is a 1-D array.
Return type
ndarray.
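This sketch illustrates the shape rule in the Returns entry: the flattened cumsum keeps the size of a, and a cumsum along an axis keeps its shape. It also passes dtype to change the accumulator and output type.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3],
               [4, 5, 6]])

flat = a.cumsum()           # [1, 3, 6, 10, 15, 21]: same size as a, shape (6,)
rows = a.cumsum(axis=1)     # [[1, 3, 6], [4, 9, 15]]: same shape as a

# dtype sets both the accumulator and the returned dtype.
as_float = a.cumsum(axis=0, dtype=jnp.float32)  # [[1., 2., 3.], [5., 7., 9.]]
```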

delete()
(self: xla::PyBuffer::pyobject) -> None

device()
(self: xla::PyBuffer::pyobject) -> jaxlib.xla_extension.Device

diagonal(offset=0, axis1=0, axis2=1)
Return specified diagonals.
LAX-backend implementation of diagonal().
The JAX version of this function will return a copy rather than a view of the input.
Original docstring below.
If a is 2-D, returns the diagonal of a with the given offset, i.e., the collection of elements of the form a[i, i+offset]. If a has more than two dimensions, then the axes specified by axis1 and axis2 are used to determine the 2-D sub-array whose diagonal is returned. The shape of the resulting array can be determined by removing axis1 and axis2 and appending an index to the right equal to the size of the resulting diagonals.
In versions of NumPy prior to 1.7, this function always returned a new, independent array containing a copy of the values in the diagonal.
In NumPy 1.7 and 1.8, it continues to return a copy of the diagonal, but depending on this fact is deprecated. Writing to the resulting array continues to work as it used to, but a FutureWarning is issued.
Starting in NumPy 1.9 it returns a read-only view on the original array. Attempting to write to the resulting array will produce an error.
In some future release, it will return a read/write view and writing to the returned array will alter your original array. The returned array will have the same type as the input array.
If you don't write to the array returned by this function, then you can just ignore all of the above.
If you depend on the current behavior, then we suggest copying the returned array explicitly, i.e., use np.diagonal(a).copy() instead of just np.diagonal(a). This will work with both past and future versions of NumPy.
Parameters
a (array_like) – Array from which the diagonals are taken.
offset (int, optional) – Offset of the diagonal from the main diagonal. Can be positive or negative. Defaults to main diagonal (0).
axis1 (int, optional) – Axis to be used as the first axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to first axis (0).
axis2 (int, optional) – Axis to be used as the second axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to second axis (1).
Returns
array_of_diagonals – If a is 2-D, then a 1-D array containing the diagonal and of the same type as a is returned unless a is a matrix, in which case a 1-D array rather than a (2-D) matrix is returned in order to maintain backward compatibility.
If a.ndim > 2, then the dimensions specified by axis1 and axis2 are removed, and a new axis inserted at the end corresponding to the diagonal.
Return type
ndarray
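The example below covers offsets on a 2-D array, followed by the 3-D rule that axis1 and axis2 are removed and the diagonal becomes the last axis. Because JAX arrays are immutable, the returned copy can never alias the input. The indices and values were worked out by hand.

```python
import jax.numpy as jnp

a = jnp.arange(9).reshape(3, 3)
# [[0, 1, 2],
#  [3, 4, 5],
#  [6, 7, 8]]
main = a.diagonal()            # [0, 4, 8]
upper = a.diagonal(offset=1)   # [1, 5]
lower = a.diagonal(offset=-1)  # [3, 7]

b = jnp.arange(8).reshape(2, 2, 2)
# Default axes (0, 1) are removed; the diagonal index is appended last:
# d_default[k, i] == b[i, i, k]
d_default = b.diagonal()                # [[0, 6], [1, 7]]
# Axes (1, 2) removed: d_inner[k, i] == b[k, i, i]
d_inner = b.diagonal(axis1=1, axis2=2)  # [[0, 3], [4, 7]]
```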

dot(b, *, precision=None)
Dot product of two arrays.
LAX-backend implementation of dot().
In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means default precision for the backend, a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two lax.Precision enums indicating separate precision for each argument.
Original docstring below.
Specifically:
If both a and b are 1-D arrays, it is the inner product of vectors (without complex conjugation).
If both a and b are 2-D arrays, it is matrix multiplication, but using matmul() or a @ b is preferred.
If either a or b is 0-D (scalar), it is equivalent to multiply(), and using numpy.multiply(a, b) or a * b is preferred.
If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b.
If a is an N-D array and b is an M-D array (where M >= 2), it is a sum product over the last axis of a and the second-to-last axis of b: dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
Parameters
a (array_like) – First argument.
b (array_like) – Second argument.
Returns
output – Returns the dot product of a and b. If a and b are both scalars or both 1-D arrays then a scalar is returned; otherwise an array is returned. If out is given, then it is returned.
Return type
ndarray
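The sketch below walks through three of the cases listed above. It also passes `precision=jax.lax.Precision.HIGHEST` to request the highest available matrix-multiplication precision. For the N-D by M-D case, the result is compared with an `jnp.einsum` contraction. That contraction is our own restatement of the formula dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]), not part of the original docs.

```python
import jax
import jax.numpy as jnp

v = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([4.0, 5.0, 6.0])
inner = v.dot(w)   # 1-D . 1-D: inner product, 1*4 + 2*5 + 3*6 = 32.0

A = jnp.arange(6.0).reshape(2, 3)
B = jnp.arange(12.0).reshape(3, 4)
# 2-D . 2-D: matrix product, with the precision hint applied to both operands.
C = jnp.dot(A, B, precision=jax.lax.Precision.HIGHEST)

# N-D . M-D: last axis of a contracted with the second-to-last axis of b.
a = jnp.ones((2, 3, 4))
b = jnp.ones((5, 4, 6))
out = jnp.dot(a, b)                       # shape (2, 3, 5, 6)
ref = jnp.einsum("ijl,klm->ijkm", a, b)   # same contraction written explicitly
```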

flatten(order='C')
Return a contiguous flattened array.
LAX-backend implementation of ravel().
The JAX version of this function will return a copy rather than a view of the input.
Original docstring below.
A 1-D array, containing the elements of the input, is returned. A copy is made only if needed.
As of NumPy 1.10, the returned array will have the same type as the input array (for example, a masked array will be returned for a masked array input).
Parameters
a (array_like) – Input array. The elements in a are read in the order specified by order, and packed as a 1-D array.
order ({'C', 'F', 'A', 'K'}, optional) – The elements of a are read using this index order. 'C' means to index the elements in row-major, C-style order, with the last axis index changing fastest, back to the first axis index changing slowest. 'F' means to index the elements in column-major, Fortran-style order, with the first index changing fastest, and the last index changing slowest. Note that the 'C' and 'F' options take no account of the memory layout of the underlying array, and only refer to the order of axis indexing. 'A' means to read the elements in Fortran-like index order if a is Fortran contiguous in memory, C-like order otherwise. 'K' means to read the elements in the order they occur in memory, except for reversing the data when strides are negative. By default, 'C' index order is used.
Returns
y – y is an array of the same subtype as a, with shape (a.size,). Note that matrices are special cased for backward compatibility: if a is a matrix, then y is a 1-D ndarray.
Return type
array_like
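A quick check of the Returns entry: with the default 'C' order, the flattened result has shape (a.size,) and the last axis varies fastest. In JAX it is always a new array.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

y = x.flatten()   # [1, 2, 3, 4, 5, 6], shape (x.size,) == (6,)
```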

property imag
Return the imaginary part of the complex argument.
LAX-backend implementation of imag().
Original docstring below.
Parameters
val (array_like) – Input array.
Returns
out – The imaginary component of the complex argument. If val is real, the type of val is used for the output. If val has complex elements, the returned type is float.
Return type
ndarray or scalar
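This sketch demonstrates the two dtype rules in the Returns entry. Complex input gives a real float result. Real input keeps its own dtype, and its imaginary part is all zeros. Explicit dtypes are used so the checks do not depend on 64-bit mode.

```python
import jax.numpy as jnp

z = jnp.array([1 + 2j, -3 + 0.5j], dtype=jnp.complex64)
im = z.imag    # [2.0, 0.5] with dtype float32

r = jnp.array([1.0, 2.0], dtype=jnp.float32).imag   # real input: [0.0, 0.0], float32
```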

is_deleted()
(self: xla::PyBuffer::pyobject) -> bool

max(axis=None, out=None, keepdims=None, initial=None, where=None)
Return the maximum of an array or maximum along an axis.
LAX-backend implementation of amax().
Original docstring below.
Parameters
a (array_like) – Input data.
axis (None or int or tuple of ints, optional) – Axis or axes along which to operate. By default, flattened input is used.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the amax method of subclasses of ndarray, however any non-default value will be. If the subclass' method does not implement keepdims any exceptions will be raised.
initial (scalar, optional) – The minimum value of an output element. Must be present to allow computation on empty slice. See numpy.ufunc.reduce for details.
where (array_like of bool, optional) – Elements to compare for the maximum. See numpy.ufunc.reduce for details.
Returns
amax – Maximum of a. If axis is None, the result is a scalar value. If axis is given, the result is an array of dimension a.ndim - 1.
Return type
ndarray or scalar
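The example below exercises the reduction arguments of max. `keepdims=True` leaves a size-one axis that broadcasts back against the input. `where` masks elements out, and `initial` supplies the floor that makes a masked or empty reduction well defined. The sketch assumes, as we understand current JAX behavior, that `where` on max needs an explicit `initial`, so the code always passes one.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 5.0, 3.0],
               [4.0, 2.0, 6.0]])

overall = x.max()                      # 6.0
per_col = x.max(axis=0)                # [4.0, 5.0, 6.0]
kept = x.max(axis=1, keepdims=True)    # [[5.0], [6.0]], shape (2, 1)
centered = x - kept                    # kept broadcasts against x

# Only elements < 5 take part; initial is used if a slice is fully masked.
masked = jnp.max(x, where=x < 5.0, initial=-jnp.inf)   # 4.0

# An empty input is allowed once initial is given.
empty = jnp.max(jnp.zeros((0,)), initial=0.0)          # 0.0
```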

mean(axis=None, dtype=None, out=None, keepdims=False, *, where=None)
Compute the arithmetic mean along the specified axis.
LAX-backend implementation of mean().
Original docstring below.
Returns the average of the array elements. The average is taken over the flattened array by default, otherwise over the specified axis. float64 intermediate and return values are used for integer inputs.
Parameters
a (array_like) – Array containing numbers whose mean is desired. If a is not an array, a conversion is attempted.
axis (None or int or tuple of ints, optional) – Axis or axes along which the means are computed. The default is to compute the mean of the flattened array.
dtype (data-type, optional) – Type to use in computing the mean. For integer inputs, the default is float64; for floating point inputs, it is the same as the input dtype.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the mean method of subclasses of ndarray, however any non-default value will be. If the subclass' method does not implement keepdims any exceptions will be raised.
Returns
m – If out=None, returns a new array containing the mean values, otherwise a reference to the output array is returned.
Return type
ndarray, see dtype parameter above
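A sketch of mean on integer input, showing axis, keepdims and the keyword-only where mask. Under JAX's default 32-bit configuration, integer input produces float32 rather than the float64 the NumPy docstring mentions. The test therefore checks only for a floating dtype.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2],
               [3, 4]])                    # integer input

m = x.mean()                               # 2.5, returned with a floating dtype
cols = x.mean(axis=0)                      # [2.0, 3.0]
rows = x.mean(axis=1, keepdims=True)       # [[1.5], [3.5]], shape (2, 1)

# where restricts the average to the selected (odd) elements: (1 + 3) / 2
odd = jnp.mean(x, where=(x % 2 == 1))      # 2.0
```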

min(axis=None, out=None, keepdims=None, initial=None, where=None)
Return the minimum of an array or minimum along an axis.
LAX-backend implementation of amin().
Original docstring below.
Parameters
a (array_like) – Input data.
axis (None or int or tuple of ints, optional) – Axis or axes along which to operate. By default, flattened input is used.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the amin method of subclasses of ndarray, however any non-default value will be. If the subclass' method does not implement keepdims any exceptions will be raised.
initial (scalar, optional) – The maximum value of an output element. Must be present to allow computation on empty slice. See numpy.ufunc.reduce for details.
where (array_like of bool, optional) – Elements to compare for the minimum. See numpy.ufunc.reduce for details.
Returns
amin – Minimum of a. If axis is None, the result is a scalar value. If axis is given, the result is an array of dimension a.ndim - 1.
Return type
ndarray or scalar
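This sketch clarifies the phrase "the maximum value of an output element". Here `initial` takes part in the comparison, so it acts as an upper cap on the result. As in the max example, the masked call passes `initial` explicitly, which we assume JAX requires when `where` is used.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 5.0, 3.0],
               [4.0, 2.0, 6.0]])

per_row = x.min(axis=1)                               # [1.0, 2.0]
masked = jnp.min(x, where=x > 2.0, initial=jnp.inf)   # smallest element > 2: 3.0
capped = jnp.min(x, initial=0.0)                      # 0.0: initial is below every element
```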

nonzero(*, size=None)
Return the indices of the elements that are non-zero.
LAX-backend implementation of nonzero().
Because the size of the output of nonzero is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which specifies the size of the output arrays: it must be specified statically for jnp.nonzero to be traced. If specified, the first size non-zero elements will be returned; if there are fewer non-zero elements than size indicates, the index arrays will be zero-padded.
Original docstring below.
Returns a tuple of arrays, one for each dimension of a, containing the indices of the non-zero elements in that dimension. The values in a are always tested and returned in row-major, C-style order.
To group the indices by element, rather than dimension, use argwhere, which returns a row for each non-zero element.
Note
When called on a zero-d array or scalar, nonzero(a) is treated as nonzero(atleast_1d(a)).
Deprecated since version 1.17.0: Use atleast_1d explicitly if this behavior is deliberate.
Parameters
a (array_like) – Input array.
Returns
tuple_of_arrays – Indices of elements that are non-zero.
Return type
tuple
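The sketch below applies the size argument inside `jax.jit`. A static size both truncates to the first hits and zero-pads when there are fewer hits. Note that the padding value 0 is also a valid index, so the padded output alone cannot distinguish a real hit at index 0 from padding. The helper function names are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5, 7])

# Outside jit the output length may depend on the data.
(eager,) = jnp.nonzero(x)          # [1, 3, 4]

@jax.jit
def first_two_nonzero(x):
    # size must be a static Python int so the output shape is known when tracing.
    (idx,) = jnp.nonzero(x, size=2)
    return idx

@jax.jit
def four_nonzero_padded(x):
    (idx,) = jnp.nonzero(x, size=4)
    return idx

truncated = first_two_nonzero(x)   # [1, 3]: only the first `size` hits
padded = four_nonzero_padded(x)    # [1, 3, 4, 0]: zero-padded to length 4
```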

on_device_size_in_bytes()
(self: xla::PyBuffer::pyobject) -> int

platform()
(self: xla::PyBuffer::pyobject) -> str

prod(axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None)
Return the product of array elements over a given axis.
LAX-backend implementation of prod().
Original docstring below.
Parameters
a (array_like) – Input data.
axis (None or int or tuple of ints, optional) – Axis or axes along which a product is performed. The default, axis=None, will calculate the product of all the elements in the input array. If axis is negative it counts from the last to the first axis.
dtype (dtype, optional) – The type of the returned array, as well as of the accumulator in which the elements are multiplied. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the prod method of subclasses of ndarray, however any non-default value will be. If the subclass' method does not implement keepdims any exceptions will be raised.
initial (scalar, optional) – The starting value for this product. See numpy.ufunc.reduce for details.
where (array_like of bool, optional) – Elements to include in the product. See numpy.ufunc.reduce for details.
Returns
product_along_axis – An array shaped as a but with the specified axis removed. Returns a reference to out if specified.
Return type
ndarray, see dtype parameter above.
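An illustration of initial as a starting factor and where as an inclusion mask for prod. Because a product has the identity 1, no extra initial value is needed with the mask. The values were computed by hand.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2],
               [3, 4]])

total = x.prod()                           # 1 * 2 * 3 * 4 = 24
per_col = x.prod(axis=0)                   # [1 * 3, 2 * 4] = [3, 8]
seeded = x.prod(initial=2)                 # 2 * 24 = 48
evens = jnp.prod(x, where=(x % 2 == 0))    # only 2 and 4 are included: 8
```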

ptp(axis=None, out=None, keepdims=False)
Range of values (maximum - minimum) along an axis.
LAX-backend implementation of ptp().
Original docstring below.
The name of the function comes from the acronym for "peak to peak".
Warning
ptp preserves the data type of the array. This means the return value for an input of signed integers with n bits (e.g. np.int8, np.int16, etc) is also a signed integer with n bits. In that case, peak-to-peak values greater than 2**(n-1)-1 will be returned as negative values.
Parameters
a (array_like) – Input values.
axis (None or int or tuple of ints, optional) – Axis along which to find the peaks. By default, flatten the array. axis may be negative, in which case it counts from the last to the first axis.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the ptp method of subclasses of ndarray, however any non-default value will be. If the subclass' method does not implement keepdims any exceptions will be raised.
Returns
ptp – A new array holding the result, unless out was specified, in which case a reference to out is returned.
Return type
ndarray
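The overflow described in the warning is reproduced below with the function form `jnp.ptp`. An int8 input whose true range is 200 exceeds 2**7 - 1 = 127, so the int8 result wraps to a negative number. Widening the dtype before the call is one workaround; this is our suggestion, not an example taken from the original docstring.

```python
import jax.numpy as jnp

x = jnp.array([-100, 100], dtype=jnp.int8)

narrow = jnp.ptp(x)                   # stays int8: the true range 200 wraps to a negative value
wide = jnp.ptp(x.astype(jnp.int16))   # widened first: 200
```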

ravel(order='C')
Return a contiguous flattened array.
LAX-backend implementation of ravel().
The JAX version of this function will return a copy rather than a view of the input.
Original docstring below.
A 1-D array, containing the elements of the input, is returned. A copy is made only if needed.
As of NumPy 1.10, the returned array will have the same type as the input array (for example, a masked array will be returned for a masked array input).
Parameters
a (array_like) – Input array. The elements in a are read in the order specified by order, and packed as a 1-D array.
order ({'C', 'F', 'A', 'K'}, optional) – The elements of a are read using this index order. 'C' means to index the elements in row-major, C-style order, with the last axis index changing fastest, back to the first axis index changing slowest. 'F' means to index the elements in column-major, Fortran-style order, with the first index changing fastest, and the last index changing slowest. Note that the 'C' and 'F' options take no account of the memory layout of the underlying array, and only refer to the order of axis indexing. 'A' means to read the elements in Fortran-like index order if a is Fortran contiguous in memory, C-like order otherwise. 'K' means to read the elements in the order they occur in memory, except for reversing the data when strides are negative. By default, 'C' index order is used.
Returns
y – y is an array of the same subtype as a, with shape (a.size,). Note that matrices are special cased for backward compatibility: if a is a matrix, then y is a 1-D ndarray.
Return type
array_like
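The two index orders 'C' and 'F' described above can be compared on a 2x3 array. This sketch uses only 'C' and 'F'. We have not verified how every JAX version treats 'A' and 'K', which refer to memory layout that JAX arrays do not expose.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

c_order = jnp.ravel(x)              # [1, 2, 3, 4, 5, 6]: last axis changes fastest
f_order = jnp.ravel(x, order="F")   # [1, 4, 2, 5, 3, 6]: first axis changes fastest
```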

property real
Return the real part of the complex argument.
LAX-backend implementation of real().
Original docstring below.
Parameters
val (array_like) – Input array.
Returns
out – The real component of the complex argument. If val is real, the type of val is used for the output. If val has complex elements, the returned type is float.
Return type
ndarray or scalar
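This sketch shows that the real part of a complex64 array comes back as float32. It also recombines the real and imaginary parts into the original complex values as a sanity check.

```python
import jax.numpy as jnp

z = jnp.array([1 + 2j, -3 + 0.5j], dtype=jnp.complex64)

re = z.real                   # [1.0, -3.0] with dtype float32
rebuilt = re + 1j * z.imag    # recombining the parts gives z back
```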

repeat(repeats, axis=None, *, total_repeat_length=None)
Repeat elements of an array.
LAX-backend implementation of repeat().
JAX adds the optional total_repeat_length parameter, which specifies the total number of repeats and defaults to sum(repeats). It must be specified for repeat to be compilable. If sum(repeats) is larger than the specified total_repeat_length, the remaining values will be discarded. If sum(repeats) is smaller than the specified total_repeat_length, the final value will be repeated.
Original docstring below.
Parameters
a (array_like) – Input array.
repeats (int or array of ints) – The number of repetitions for each element. repeats is broadcast to fit the shape of the given axis.
axis (int, optional) – The axis along which to repeat values. By default, use the flattened input array, and return a flat output array.
Returns
repeated_array – Output array which has the same shape as a, except along the given axis.
Return type
ndarray
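A sketch of why total_repeat_length is needed under `jax.jit`. When repeats is a traced array, its sum is not known at trace time, so the output length has to be given statically. Here it is chosen to equal sum(repeats), so the compiled result matches the eager one. The function name `repeat_traced` is illustrative.

```python
import jax
import jax.numpy as jnp

x = jnp.array([10, 20, 30])

# Eager: repeats are concrete, so the output length can be computed from them.
eager = jnp.repeat(x, jnp.array([1, 2, 3]))       # [10, 20, 20, 30, 30, 30]

@jax.jit
def repeat_traced(repeats):
    # repeats is traced here; the output length must be fixed statically.
    return jnp.repeat(x, repeats, total_repeat_length=6)

traced = repeat_traced(jnp.array([1, 2, 3]))      # same values as eager
```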

round(decimals=0, out=None)
Evenly round to the given number of decimals.
LAX-backend implementation of around().
Original docstring below.
Parameters
a (array_like) – Input data.
decimals (int, optional) – Number of decimal places to round to (default: 0). If decimals is negative, it specifies the number of positions to the left of the decimal point.
Returns
rounded_array – An array of the same type as a, containing the rounded values. Unless out was specified, a new array is created. A reference to the result is returned.
The real and imaginary parts of complex numbers are rounded separately. The result of rounding a float is a float.
Return type
ndarray
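The sketch below illustrates "evenly round": exact halves go to the nearest even value. It also covers positive and negative decimals. Results with decimals != 0 are compared with a tolerance, because values such as 1.2 are not exactly representable in floating point.

```python
import jax.numpy as jnp

halves = jnp.round(jnp.array([0.5, 1.5, 2.5, -0.5]))        # [0., 2., 2., -0.]: ties to even
one_dp = jnp.round(jnp.array([1.234, 5.678]), decimals=1)   # approximately [1.2, 5.7]
tens = jnp.round(jnp.array([123.0, 157.0]), decimals=-1)    # approximately [120., 160.]
```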

searchsorted(v, side='left', sorter=None)
Find indices where elements should be inserted to maintain order.
LAX-backend implementation of searchsorted().
Original docstring below.
Find the indices into a sorted array a such that, if the corresponding elements in v were inserted before the indices, the order of a would be preserved.
Assuming that a is sorted:
side     returned index i satisfies
left     a[i-1] < v <= a[i]
right    a[i-1] <= v < a[i]
Parameters
a (1-D array_like) – Input array. If sorter is None, then it must be sorted in ascending order, otherwise sorter must be an array of indices that sort it.
v (array_like) – Values to insert into a.
side ({'left', 'right'}, optional) – If 'left', the index of the first suitable location found is given. If 'right', return the last such index. If there is no suitable index, return either 0 or N (where N is the length of a).
Returns
indices – Array of insertion points with the same shape as v.
Return type
array of ints
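The example checks the side table on a sorted array that contains a repeated value. It also shows the 0 and N results for values that fall below or above every element. The comments restate the table inequality for each case.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 2, 3])                        # must already be sorted ascending

left = jnp.searchsorted(a, 2)                      # 1, since a[0] < 2 <= a[1]
right = jnp.searchsorted(a, 2, side="right")       # 3, since a[2] <= 2 < a[3]
edges = jnp.searchsorted(a, jnp.array([0, 4]))     # [0, 4]: 0 and N = len(a)
```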

sort(axis=-1, kind='quicksort', order=None)
Return a sorted copy of an array.
LAX-backend implementation of sort().
Original docstring below.
Parameters
a (array_like) – Array to be sorted.
axis (int or None, optional) – Axis along which to sort. If None, the array is flattened before sorting. The default is -1, which sorts along the last axis.
kind ({'quicksort', 'mergesort', 'heapsort', 'stable'}, optional) –
Sorting algorithm. The default is 'quicksort'. Note that both 'stable' and 'mergesort' use timsort or radix sort under the covers and, in general, the actual implementation will vary with data type. The 'mergesort' option is retained for backwards compatibility.
Changed in version 1.15.0: The 'stable' option was added.
order (str or list of str, optional) – When a is an array with fields defined, this argument specifies which fields to compare first, second, etc. A single field can be specified as a string, and not all fields need be specified, but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties.
Returns
sorted_array – Array of the same type and shape as a.
Return type
ndarray
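This sketch compares the default axis=-1 (sort each row), axis=0 (sort each column) and axis=None (flatten, then sort). It leaves kind unset, so no particular algorithm is assumed.

```python
import jax.numpy as jnp

x = jnp.array([[3, 1, 2],
               [0, 5, 4]])

last_axis = jnp.sort(x)             # default axis=-1: [[1, 2, 3], [0, 4, 5]]
first_axis = jnp.sort(x, axis=0)    # [[0, 1, 2], [3, 5, 4]]
flat = jnp.sort(x, axis=None)       # flattened first: [0, 1, 2, 3, 4, 5]
```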

split(indices_or_sections, axis=0)
Split an array into multiple sub-arrays as views into ary.
LAX-backend implementation of split().
The JAX version of this function will return a copy rather than a view of the input.
Original docstring below.
Parameters
ary (ndarray) – Array to be divided into sub-arrays.
indices_or_sections (int or 1-D array) –
If indices_or_sections is an integer, N, the array will be divided into N equal arrays along axis. If such a split is not possible, an error is raised.
If indices_or_sections is a 1-D array of sorted integers, the entries indicate where along axis the array is split. For example, [2, 3] would, for axis=0, result in
ary[:2]
ary[2:3]
ary[3:]
If an index exceeds the dimension of the array along axis, an empty sub-array is returned correspondingly.
axis (int, optional) – The axis along which to split, default is 0.
Returns
sub-arrays – A list of sub-arrays as views into ary.
Return type
list of ndarrays
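The two forms of indices_or_sections are shown on a length-6 vector. An integer gives equal sections, and a sorted list of indices gives the ary[:2], ary[2:3], ary[3:] pattern described above.

```python
import jax.numpy as jnp

x = jnp.arange(6)

halves = jnp.split(x, 2)         # [array([0, 1, 2]), array([3, 4, 5])]
pieces = jnp.split(x, [2, 3])    # x[:2], x[2:3], x[3:] -> [0, 1], [2], [3, 4, 5]
```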

squeeze(axis=None)
Remove single-dimensional entries from the shape of an array.
LAX-backend implementation of squeeze().
The JAX version of this function will return a copy rather than a view of the input.
Original docstring below.
Returns
squeezed – The input array, but with all or a subset of the dimensions of length 1 removed. This is always a itself or a view into a. Note that if all axes are squeezed, the result is a 0-d array and not a scalar.
Return type
ndarray
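This sketch compares squeezing every length-1 axis with squeezing a single chosen one. It also confirms the note that squeezing all axes leaves a 0-d array rather than a Python scalar.

```python
import jax.numpy as jnp

x = jnp.zeros((1, 3, 1))

all_dropped = x.squeeze()                    # shape (3,)
only_first = x.squeeze(axis=0)               # shape (3, 1)
zero_d = jnp.zeros((1, 1)).squeeze()         # shape (): a 0-d array
```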

std(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None)
Compute the standard deviation along the specified axis.
LAX-backend implementation of std().
Original docstring below.
Returns the standard deviation, a measure of the spread of a distribution, of the array elements. The standard deviation is computed for the flattened array by default, otherwise over the specified axis.
Parameters
a (array_like) – Calculate the standard deviation of these values.
axis (None or int or tuple of ints, optional) – Axis or axes along which the standard deviation is computed. The default is to compute the standard deviation of the flattened array.
dtype (dtype, optional) – Type to use in computing the standard deviation. For arrays of integer type the default is float64, for arrays of float types it is the same as the array type.
ddof (int, optional) – Means Delta Degrees of Freedom. The divisor used in calculations is N - ddof, where N represents the number of elements. By default ddof is zero.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the std method of subclasses of ndarray, however any non-default value will be. If the subclass' method does not implement keepdims any exceptions will be raised.
Returns
standard_deviation – If out is None, return a new array containing the standard deviation, otherwise return a reference to the output array.
Return type
ndarray, see dtype parameter above.
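The sketch below checks the ddof description against the definition sqrt(sum((x - mean)**2) / (N - ddof)). It uses ddof=0 (population) and ddof=1 (sample).

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])
n = x.size

pop = x.std()              # divisor N
sample = x.std(ddof=1)     # divisor N - 1

# Sum of squared deviations from the mean: 2.25 + 0.25 + 0.25 + 2.25 = 5.0
sq = jnp.sum((x - x.mean()) ** 2)
pop_ref = jnp.sqrt(sq / n)           # sqrt(5 / 4), about 1.118
sample_ref = jnp.sqrt(sq / (n - 1))  # sqrt(5 / 3), about 1.291
```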

sum(axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None)
Sum of array elements over a given axis.
LAX-backend implementation of sum().
Original docstring below.
Parameters
a (array_like) – Elements to sum.
axis (None or int or tuple of ints, optional) – Axis or axes along which a sum is performed. The default, axis=None, will sum all of the elements of the input array. If axis is negative it counts from the last to the first axis.
dtype (dtype, optional) – The type of the returned array and of the accumulator in which the elements are summed. The dtype of a is used by default unless a has an integer dtype of less precision than the default platform integer. In that case, if a is signed then the platform integer is used while if a is unsigned then an unsigned integer of the same precision as the platform integer is used.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the sum method of subclasses of ndarray, however any non-default value will be. If the subclass' method does not implement keepdims any exceptions will be raised.
initial (scalar, optional) – Starting value for the sum. See numpy.ufunc.reduce for details.
where (array_like of bool, optional) – Elements to include in the sum. See numpy.ufunc.reduce for details.
Returns
sum_along_axis – An array with the same shape as a, with the specified axis removed. If a is a 0-d array, or if axis is None, a scalar is returned. If an output array is specified, a reference to out is returned.
Return type
ndarray
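A compact tour of the reduction arguments of sum on a 2x2 integer array, with values computed by hand. It covers axis, keepdims, initial as a starting value, and where as an inclusion mask.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2],
               [3, 4]])

total = x.sum()                            # 10
per_col = x.sum(axis=0)                    # [4, 6]
per_row = x.sum(axis=1, keepdims=True)     # [[3], [7]], shape (2, 1)
shifted = x.sum(initial=5)                 # 5 + 10 = 15
big_only = jnp.sum(x, where=(x > 1))       # 2 + 3 + 4 = 9
```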

swapaxes(axis1, axis2)
Interchange two axes of an array.
LAX-backend implementation of swapaxes().
The JAX version of this function will return a copy rather than a view of the input.
Original docstring below.
Returns
a_swapped – For NumPy >= 1.10.0, if a is an ndarray, then a view of a is returned; otherwise a new array is created. For earlier NumPy versions a view of a is returned only if the order of the axes is changed, otherwise the input array is returned.
Return type
ndarray
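The sketch below swaps the first and last axes of a 3-D array. It checks the result against an equivalent `jnp.transpose` permutation and against a single element, using the identity y[k, j, i] == x[i, j, k].

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

y = x.swapaxes(0, 2)                    # shape (4, 3, 2)
same = jnp.transpose(x, (2, 1, 0))      # the equivalent axis permutation
# Spot check: y[3, 1, 0] == x[0, 1, 3] == 0*12 + 1*4 + 3 == 7
```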

take(indices, axis=None, out=None, mode=None)
Take elements from an array along an axis.
LAX-backend implementation of take().
Original docstring below.
When axis is not None, this function does the same thing as "fancy" indexing (indexing arrays using arrays); however, it can be easier to use if you need elements along a given axis. A call such as np.take(arr, indices, axis=3) is equivalent to arr[:,:,:,indices,...].
Explained without fancy indexing, this is equivalent to the following use of ndindex, which sets each of ii, jj, and kk to a tuple of indices:
    Ni, Nk = a.shape[:axis], a.shape[axis+1:]
    Nj = indices.shape
    for ii in ndindex(Ni):
        for jj in ndindex(Nj):
            for kk in ndindex(Nk):
                out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
Parameters
a (array_like (Ni..., M, Nk...)) – The source array.
indices (array_like (Nj...)) – The indices of the values to extract.
axis (int, optional) – The axis over which to select values. By default, the flattened input array is used.
mode ({'raise', 'wrap', 'clip'}, optional) –
Specifies how out-of-bounds indices will behave.
'raise' – raise an error (default)
'wrap' – wrap around
'clip' – clip to the range
'clip' mode means that all indices that are too large are replaced by the index that addresses the last element along that axis. Note that this disables indexing with negative numbers.
Returns
out – The returned array has the same type as a.
Return type
ndarray (Ni..., Nj..., Nk...)
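The ndindex loop above can be made runnable with NumPy's `np.ndindex` and an explicitly allocated `out` array, both of which are our additions. This sketch checks the loop against `jnp.take` and against fancy indexing. It uses in-bounds indices only, because JAX's handling of mode for out-of-bounds indices is not exercised here.

```python
import numpy as np
import jax.numpy as jnp

a = np.arange(24).reshape(2, 3, 4)
indices = np.array([[0, 2], [1, 1]])
axis = 1

result = jnp.take(a, indices, axis=axis)   # shape Ni + Nj + Nk = (2, 2, 2, 4)

# The docstring's loop, with `out` allocated up front so it can run.
Ni, Nk = a.shape[:axis], a.shape[axis + 1:]
Nj = indices.shape
out = np.empty(Ni + Nj + Nk, dtype=a.dtype)
for ii in np.ndindex(Ni):
    for jj in np.ndindex(Nj):
        for kk in np.ndindex(Nk):
            out[ii + jj + kk] = a[ii + (indices[jj],) + kk]

fancy = a[:, indices]   # the equivalent fancy-indexing form for axis=1
```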

tile(reps)
Construct an array by repeating A the number of times given by reps.
LAX-backend implementation of tile().
Original docstring below.
If reps has length d, the result will have dimension of max(d, A.ndim).
If A.ndim < d, A is promoted to be d-dimensional by prepending new axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D replication. If this is not the desired behavior, promote A to d-dimensions manually before calling this function.
If A.ndim > d, reps is promoted to A.ndim by prepending 1's to it. Thus for an A of shape (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2).
Note: Although tile may be used for broadcasting, it is strongly recommended to use numpy's broadcasting operations and functions.
Parameters
A (array_like) – The input array.
reps (array_like) – The number of repetitions of A along each axis.
Returns
c – The tiled output array.
Return type
ndarray
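The example below covers both promotion rules. When reps is longer than A.ndim, A gains leading axes. When A.ndim is larger than the length of reps, reps gains leading 1's. The shapes in the comments follow directly from those rules.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3])                 # shape (3,)
flat = jnp.tile(a, 2)                    # [1, 2, 3, 1, 2, 3]
stacked = jnp.tile(a, (2, 1))            # a promoted to (1, 3) -> shape (2, 3)

b = jnp.ones((2, 3))
wide = jnp.tile(b, 2)                    # reps promoted to (1, 2) -> shape (2, 6)
deep = jnp.tile(b, (2, 1, 1))            # b promoted to (1, 2, 3) -> shape (2, 2, 3)
```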

to_py()
(self: xla::PyBuffer::pyobject) -> StatusOr[object]

trace(offset=0, axis1=0, axis2=1, dtype=None, out=None)
Return the sum along diagonals of the array.
LAX-backend implementation of trace().
Original docstring below.
If a is 2-D, the sum along its diagonal with the given offset is returned, i.e., the sum of elements a[i,i+offset] for all i.
If a has more than two dimensions, then the axes specified by axis1 and axis2 are used to determine the 2-D sub-arrays whose traces are returned. The shape of the resulting array is the same as that of a with axis1 and axis2 removed.
Parameters
a (array_like) – Input array, from which the diagonals are taken.
offset (int, optional) – Offset of the diagonal from the main diagonal. Can be both positive and negative. Defaults to 0.
axis1 (int, optional) – Axes to be used as the first and second axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults are the first two axes of a.
axis2 (int, optional) – Axes to be used as the first and second axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults are the first two axes of a.
dtype (dtype, optional) – Determines the data-type of the returned array and of the accumulator where the elements are summed. If dtype has the value None and a is of integer type of precision less than the default integer precision, then the default integer precision is used. Otherwise, the precision is the same as that of a.
Returns
sum_along_diagonals – If a is 2-D, the sum along the diagonal is returned. If a has larger dimensions, then an array of sums along diagonals is returned.
Return type
ndarray
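The sketch below computes trace on a 2-D array with and without an offset. It then applies trace to a 3-D array, where the axes named by axis1 and axis2 are removed and one trace is returned per remaining index. The values were worked out by hand.

```python
import jax.numpy as jnp

a = jnp.arange(9).reshape(3, 3)
main = a.trace()              # 0 + 4 + 8 = 12
upper = a.trace(offset=1)     # 1 + 5 = 6

b = jnp.arange(8).reshape(2, 2, 2)
# Default axes (0, 1): result[k] = b[0, 0, k] + b[1, 1, k]
outer = b.trace()                      # [0 + 6, 1 + 7] = [6, 8]
# Axes (1, 2): result[k] = b[k, 0, 0] + b[k, 1, 1]
inner = b.trace(axis1=1, axis2=2)      # [0 + 3, 4 + 7] = [3, 11]
```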

unsafe_buffer_pointer()
(self: xla::PyBuffer::pyobject) -> StatusOr[int]

var(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None)
Compute the variance along the specified axis.
LAX-backend implementation of var().
Original docstring below.
Returns the variance of the array elements, a measure of the spread of a distribution. The variance is computed for the flattened array by default, otherwise over the specified axis.
Parameters
a (array_like) – Array containing numbers whose variance is desired. If a is not an array, a conversion is attempted.
axis (None or int or tuple of ints, optional) – Axis or axes along which the variance is computed. The default is to compute the variance of the flattened array.
dtype (data-type, optional) – Type to use in computing the variance. For arrays of integer type the default is float64; for arrays of float types it is the same as the array type.
ddof (int, optional) – "Delta Degrees of Freedom": the divisor used in the calculation is N - ddof, where N represents the number of elements. By default ddof is zero.
keepdims (bool, optional) –
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the var method of subclasses of ndarray, however any non-default value will be. If the subclass' method does not implement keepdims any exceptions will be raised.
Returns
variance – If out=None, returns a new array containing the variance; otherwise, a reference to the output array is returned.
Return type
ndarray, see dtype parameter above
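A sketch combining axis, ddof and keepdims on a 2x2 array, plus a check that std is the square root of var. The expected values follow from the divisor N - ddof stated above.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

v_all = x.var()                          # population variance of all 4 values: 5 / 4 = 1.25
v_cols = x.var(axis=0, ddof=1)           # sample variance per column: [2.0, 2.0]
v_rows = x.var(axis=1, keepdims=True)    # [[0.25], [0.25]], shape (2, 1)
```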

xla_dynamic_shape()
(self: xla::PyBuffer::pyobject) -> StatusOr[jaxlib.xla_extension.Shape]

xla_shape()
(self: xla::PyBuffer::pyobject) -> jaxlib.xla_extension.Shape

property