# jax.numpy package

Implements the NumPy API, using the primitives in jax.lax.
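The following is a minimal sketch of what mirroring the NumPy API means in practice. The same function body runs against NumPy and against `jax.numpy` (imported under the conventional alias `jnp`). Because the `jnp` version is built from `jax.lax` primitives, it can also be wrapped in `jax.jit`. The helper `normalize` is hypothetical and belongs to neither library.

```python
import numpy as np
import jax
import jax.numpy as jnp


def normalize(xp, x):
    # Hypothetical helper: works with either module because jax.numpy
    # mirrors NumPy's function names and signatures.
    return (x - xp.mean(x)) / xp.std(x)


x_np = np.array([1.0, 2.0, 3.0, 4.0])
x_jnp = jnp.asarray(x_np)

ref = normalize(np, x_np)    # NumPy result (float64)
out = normalize(jnp, x_jnp)  # jax.numpy result (float32 by default)

# The jax.numpy version can be traced and compiled by XLA.
normalize_jit = jax.jit(lambda x: normalize(jnp, x))
out_jit = normalize_jit(x_jnp)

print(np.allclose(ref, out, atol=1e-6), np.allclose(ref, out_jit, atol=1e-6))
```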

While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.

• Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, JAX can often provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides an alternative pure indexed update function, jax.ops.index_update().

• NumPy aggressively promotes values to the float64 type. JAX is sometimes less aggressive about type promotion.
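Here is a short sketch of both points. It uses the `x.at[i].set(y)` indexed-update syntax. In more recent JAX releases, this syntax replaced jax.ops.index_update(), which was deprecated and later removed. The dtypes in the comments assume JAX's default 32-bit mode, with the `jax_enable_x64` flag left off.

```python
import numpy as np
import jax.numpy as jnp

# Immutability: item assignment on a JAX array raises a TypeError.
x = jnp.zeros(3)
try:
    x[1] = 5.0
except TypeError:
    print("JAX arrays do not support in-place assignment")

# Functional update: returns a new array and leaves `x` untouched.
y = x.at[1].set(5.0)
print(x)  # still all zeros
print(y)  # zeros except index 1, which is 5.0

# Type promotion: an int32 array combined with a Python float.
np_result = np.arange(3, dtype=np.int32) + 1.5
jnp_result = jnp.arange(3, dtype=jnp.int32) + 1.5
print(np_result.dtype)   # float64
print(jnp_result.dtype)  # float32 when 64-bit mode is off (JAX's default)
```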

A small number of NumPy operations that have data-dependent output shapes are incompatible with jax.jit() compilation, because the XLA compiler requires the shapes of arrays to be known at compile time. For example, while it would be possible to provide a JAX implementation of numpy.nonzero(), it could not be JIT-compiled, because the shape of its output depends on the contents of the input data.
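The sketch below illustrates the data-dependent shape problem using `jnp.nonzero`. The workaround shown is an assumption about newer JAX releases, where `jnp.nonzero` accepts a static `size` argument. That argument fixes the output length ahead of time and pads any unused slots (with 0 by default), so the function can be traced under `jax.jit`.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 5])

# Eagerly (outside jit), the output length depends on the data:
# this input has two non-zero entries.
(idx_eager,) = jnp.nonzero(x)


@jax.jit
def nonzero_padded(x):
    # Under jit, shapes must be static. Fixing `size` makes the output
    # shape independent of the data; missing entries are padded with 0.
    return jnp.nonzero(x, size=3)


(idx_jit,) = nonzero_padded(x)
print(idx_eager)  # [1 3]
print(idx_jit)    # [1 3 0]
```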

Not every function in NumPy is implemented; contributions are welcome!

| Name | Description |
|---|---|
| | Calculate the absolute value element-wise. |
| | Calculate the absolute value element-wise. |
| `add(x1, x2)` | Add arguments element-wise. |
| `all(a[, axis, out, keepdims])` | Test whether all array elements along a given axis evaluate to True. |
| `allclose(a, b[, rtol, atol, equal_nan])` | Returns True if two arrays are element-wise equal within a tolerance. |
| `alltrue(a[, axis, out, keepdims])` | Test whether all array elements along a given axis evaluate to True. |
| `amax(a[, axis, out, keepdims, initial, where])` | Return the maximum of an array or maximum along an axis. |
| `amin(a[, axis, out, keepdims, initial, where])` | Return the minimum of an array or minimum along an axis. |
| | Return the angle of the complex argument. |
| `any(a[, axis, out, keepdims])` | Test whether any array element along a given axis evaluates to True. |
| `append(arr, values[, axis])` | Append values to the end of an array. |
| `apply_along_axis(func1d, axis, arr, *args, …)` | Apply a function to 1-D slices along the given axis. |
| `apply_over_axes(func, a, axes)` | Apply a function repeatedly over multiple axes. |
| `arange(start[, stop, step, dtype])` | Return evenly spaced values within a given interval. |
| | Trigonometric inverse cosine, element-wise. |
| | Inverse hyperbolic cosine, element-wise. |
| | Inverse sine, element-wise. |
| | Inverse hyperbolic sine, element-wise. |
| | Trigonometric inverse tangent, element-wise. |
| `arctan2(x1, x2)` | Element-wise arc tangent of x1/x2, choosing the quadrant correctly. |
| | Inverse hyperbolic tangent, element-wise. |
| `argmax(a[, axis, out])` | Returns the indices of the maximum values along an axis. |
| `argmin(a[, axis, out])` | Returns the indices of the minimum values along an axis. |
| `argsort(a[, axis, kind, order])` | Returns the indices that would sort an array. |
| | Find the indices of array elements that are non-zero, grouped by element. |
| `around(a[, decimals, out])` | Round an array to the given number of decimals. |
| `array(object[, dtype, copy, order, ndmin])` | Create an array. |
| `array_equal(a1, a2[, equal_nan])` | True if two arrays have the same shape and elements, False otherwise. |
| `array_equiv(a1, a2)` | Returns True if input arrays are shape consistent and all elements equal. |
| `array_repr(arr[, max_line_width, precision, …])` | Return the string representation of an array. |
| `array_split(ary, indices_or_sections[, axis])` | Split an array into multiple sub-arrays. |
| `array_str(a[, max_line_width, precision, …])` | Return a string representation of the data in an array. |
| `asarray(a[, dtype, order])` | Convert the input to an array. |
| `atleast_1d(*arys)` | Convert inputs to arrays with at least one dimension. |
| `atleast_2d(*arys)` | View inputs as arrays with at least two dimensions. |
| `atleast_3d(*arys)` | View inputs as arrays with at least three dimensions. |
| `average(a[, axis, weights, returned])` | Compute the weighted average along the specified axis. |
| `bartlett(*args, **kwargs)` | Return the Bartlett window. |
| `bincount(x[, weights, minlength, length])` | Count the number of occurrences of each value in an array of non-negative ints. |
| `bitwise_and(x1, x2)` | Compute the bit-wise AND of two arrays element-wise. |
| | Compute bit-wise inversion, or bit-wise NOT, element-wise. |
| `bitwise_or(x1, x2)` | Compute the bit-wise OR of two arrays element-wise. |
| `bitwise_xor(x1, x2)` | Compute the bit-wise XOR of two arrays element-wise. |
| `blackman(*args, **kwargs)` | Return the Blackman window. |
| `block(arrays)` | Assemble an nd-array from nested lists of blocks. |
| `bool_` | |
| `broadcast_arrays(*args)` | Like NumPy’s broadcast_arrays, but doesn’t return views. |
| `broadcast_to(arr, shape)` | Broadcast an array to a new shape. |
| `can_cast(from_, to[, casting])` | Returns True if a cast between data types can occur according to the casting rule. |
| | Return the cube-root of an array, element-wise. |
| `cdouble` | Alias of `jax._src.numpy.lax_numpy.complex128`. |
| | Return the ceiling of the input, element-wise. |
| `character` | Abstract base class of all character string scalar types. |
| `choose(a, choices[, out, mode])` | Construct an array from an index array and a set of arrays to choose from. |
| `clip(a[, a_min, a_max, out])` | Clip (limit) the values in an array. |
| | Stack 1-D arrays as columns into a 2-D array. |
| `complex_` | Alias of `jax._src.numpy.lax_numpy.complex128`. |
| `complex128` | |
| `complex64` | |
| `complexfloating` | Abstract base class of all complex number scalar types that are made up of floating-point numbers. |
| `ComplexWarning` | The warning raised when casting a complex dtype to a real dtype. |
| `compress(condition, a[, axis, out])` | Return selected slices of an array along a given axis. |
| `concatenate(arrays[, axis])` | Join a sequence of arrays along an existing axis. |
| | Return the complex conjugate, element-wise. |
| | Return the complex conjugate, element-wise. |
| `convolve(a, v[, mode, precision])` | Returns the discrete, linear convolution of two one-dimensional sequences. |
| `copysign(x1, x2)` | Change the sign of x1 to that of x2, element-wise. |
| `corrcoef(x[, y, rowvar])` | Return Pearson product-moment correlation coefficients. |
| `correlate(a, v[, mode, precision])` | Cross-correlation of two 1-dimensional sequences. |
| | Cosine, element-wise. |
| | Hyperbolic cosine, element-wise. |
| `count_nonzero(a[, axis, keepdims])` | Counts the number of non-zero values in the array a. |
| `cov(m[, y, rowvar, bias, ddof, fweights, …])` | Estimate a covariance matrix, given data and weights. |
| `cross(a, b[, axisa, axisb, axisc, axis])` | Return the cross product of two (arrays of) vectors. |
| `csingle` | Alias of `jax._src.numpy.lax_numpy.complex64`. |
| `cumprod(a[, axis, dtype, out])` | Return the cumulative product of elements along a given axis. |
| `cumproduct(a[, axis, dtype, out])` | Return the cumulative product of elements along a given axis. |
| `cumsum(a[, axis, dtype, out])` | Return the cumulative sum of the elements along a given axis. |
| | Convert angles from degrees to radians. |
| | Convert angles from radians to degrees. |
| `diag(v[, k])` | Extract a diagonal or construct a diagonal array. |
| `diagflat(v[, k])` | Create a two-dimensional array with the flattened input as a diagonal. |
| `diag_indices(n[, ndim])` | Return the indices to access the main diagonal of an array. |
| | Return the indices to access the main diagonal of an n-dimensional array. |
| `diagonal(a[, offset, axis1, axis2])` | Return specified diagonals. |
| `diff(a[, n, axis, prepend, append])` | Calculate the n-th discrete difference along the given axis. |
| `digitize(x, bins[, right])` | Return the indices of the bins to which each value in the input array belongs. |
| `divide(x1, x2)` | Returns a true division of the inputs, element-wise. |
| `divmod(x1, x2)` | Return element-wise quotient and remainder simultaneously. |
| `dot(a, b, *[, precision])` | Dot product of two arrays. |
| `double` | Alias of `jax._src.numpy.lax_numpy.float64`. |
| `dsplit(ary, indices_or_sections)` | Split an array into multiple sub-arrays along the 3rd axis (depth). |
| `dstack(tup)` | Stack arrays in sequence depth wise (along the third axis). |
| `dtype(obj[, align, copy])` | Create a data type object. |
| `ediff1d(ary[, to_end, to_begin])` | The differences between consecutive elements of an array. |
| `einsum(*operands[, out, optimize, precision])` | Evaluates the Einstein summation convention on the operands. |
| `einsum_path(subscripts, *operands[, optimize])` | Evaluates the lowest cost contraction order for an einsum expression by considering the creation of intermediate arrays. |
| `empty(shape[, dtype])` | Return a new array of given shape and type, filled with zeros. |
| `empty_like(a[, dtype, shape])` | Return an array of zeros with the same shape and type as a given array. |
| `equal(x1, x2)` | Return (x1 == x2) element-wise. |
| | Calculate the exponential of all elements in the input array. |
| | Calculate `2**p` for all p in the input array. |
| `expand_dims(a, axis)` | Expand the shape of an array. |
| | Calculate `exp(x) - 1` for all elements in the array. |
| `extract(condition, arr)` | Return the elements of an array that satisfy some condition. |
| `eye(N[, M, k, dtype])` | Return a 2-D array with ones on the diagonal and zeros elsewhere. |
| | Compute the absolute values element-wise. |
| `finfo(dtype)` | Machine limits for floating point types. |
| `fix(x[, out])` | Round to the nearest integer towards zero. |
| | Return indices that are non-zero in the flattened version of a. |
| `flexible` | Abstract base class of all scalar types without predefined length. |
| `flip(m[, axis])` | Reverse the order of elements in an array along the given axis. |
| | Flip array in the left/right direction. |
| | Flip array in the up/down direction. |
| `float_` | Alias of `jax._src.numpy.lax_numpy.float64`. |
| `float16` | |
| `float32` | |
| `float64` | |
| `floating` | Abstract base class of all floating-point scalar types. |
| `float_power(x1, x2)` | First array elements raised to powers from second array, element-wise. |
| | Return the floor of the input, element-wise. |
| `floor_divide(x1, x2)` | Return the largest integer smaller than or equal to the division of the inputs. |
| `fmax(x1, x2)` | Element-wise maximum of array elements. |
| `fmin(x1, x2)` | Element-wise minimum of array elements. |
| `fmod(x1, x2)` | Return the element-wise remainder of division. |
| | Decompose the elements of x into mantissa and twos exponent. |
| `full(shape, fill_value[, dtype])` | Return a new array of given shape and type, filled with fill_value. |
| `full_like(a, fill_value[, dtype, shape])` | Return a full array with the same shape and type as a given array. |
| `gcd(x1, x2)` | Returns the greatest common divisor of \|x1\| and \|x2\|. |
| `geomspace(start, stop[, num, endpoint, …])` | Return numbers spaced evenly on a log scale (a geometric progression). |
| `gradient(f, *varargs[, axis, edge_order])` | Return the gradient of an N-dimensional array. |
| `greater(x1, x2)` | Return the truth value of (x1 > x2) element-wise. |
| `greater_equal(x1, x2)` | Return the truth value of (x1 >= x2) element-wise. |
| `hamming(*args, **kwargs)` | Return the Hamming window. |
| `hanning(*args, **kwargs)` | Return the Hanning window. |
| `heaviside(x1, x2)` | Compute the Heaviside step function. |
| `histogram(a[, bins, range, weights, density])` | Compute the histogram of a set of data. |
| `histogram_bin_edges(a[, bins, range, weights])` | Calculate only the edges of the bins used by the histogram function. |
| `histogram2d(x, y[, bins, range, weights, …])` | Compute the bi-dimensional histogram of two data samples. |
| `histogramdd(sample[, bins, range, weights, …])` | Compute the multidimensional histogram of some data. |
| `hsplit(ary, indices_or_sections)` | Split an array into multiple sub-arrays horizontally (column-wise). |
| `hstack(tup)` | Stack arrays in sequence horizontally (column wise). |
| `hypot(x1, x2)` | Given the “legs” of a right triangle, return its hypotenuse. |
| `i0(x)` | Modified Bessel function of the first kind, order 0. |
| `identity(n[, dtype])` | Return the identity array. |
| `iinfo(type)` | Machine limits for integer types. |
| `imag(val)` | Return the imaginary part of the complex argument. |
| `in1d(ar1, ar2[, assume_unique, invert])` | Test whether each element of a 1-D array is also present in a second array. |
| `indices(dimensions[, dtype, sparse])` | Return an array representing the indices of a grid. |
| `inexact` | Abstract base class of all numeric scalar types with a (potentially) inexact representation of the values in its range, such as floating-point numbers. |
| `inner(a, b, *[, precision])` | Inner product of two arrays. |
| `int_` | Alias of `jax._src.numpy.lax_numpy.int64`. |
| `int16` | |
| `int32` | |
| `int64` | |
| `int8` | |
| `integer` | Abstract base class of all integer scalar types. |
| `interp(x, xp, fp[, left, right, period])` | One-dimensional linear interpolation. |
| `intersect1d(ar1, ar2[, assume_unique, …])` | Find the intersection of two arrays. |
| | Compute bit-wise inversion, or bit-wise NOT, element-wise. |
| `isclose(a, b[, rtol, atol, equal_nan])` | Returns a boolean array where two arrays are element-wise equal within a tolerance. |
| | Returns a bool array, where True if input element is complex. |
| | Check for a complex type or an array of complex numbers. |
| | Test element-wise for finiteness (neither infinity nor Not a Number). |
| `isin(element, test_elements[, …])` | Calculates element in test_elements, broadcasting over element only. |
| | Test element-wise for positive or negative infinity. |
| | Test element-wise for NaN and return the result as a boolean array. |
| `isneginf(x[, out])` | Test element-wise for negative infinity, return result as bool array. |
| `isposinf(x[, out])` | Test element-wise for positive infinity, return result as bool array. |
| | Returns a bool array, where True if input element is real. |
| | Return True if x is not a complex type or an array of complex numbers. |
| `isscalar(element)` | Returns True if the type of element is a scalar type. |
| `issubdtype(arg1, arg2)` | Returns True if the first argument is a typecode lower/equal in the type hierarchy. |
| `issubsctype(arg1, arg2)` | Determine if the first argument is a subclass of the second argument. |
| | Check whether or not an object can be iterated over. |
| `ix_(*args)` | Construct an open mesh from multiple sequences. |
| `kaiser(*args, **kwargs)` | Return the Kaiser window. |
| `kron(a, b)` | Kronecker product of two arrays. |
| `lcm(x1, x2)` | Returns the lowest common multiple of \|x1\| and \|x2\|. |
| `ldexp(x1, x2)` | Returns `x1 * 2**x2`, element-wise. |
| `left_shift(x1, x2)` | Shift the bits of an integer to the left. |
| `less(x1, x2)` | Return the truth value of (x1 < x2) element-wise. |
| `less_equal(x1, x2)` | Return the truth value of (x1 <= x2) element-wise. |
| `lexsort(keys[, axis])` | Perform an indirect stable sort using a sequence of keys. |
| `linspace(start, stop[, num, endpoint, …])` | Return evenly spaced numbers over a specified interval. |
| `load(file[, mmap_mode, allow_pickle, …])` | Load arrays or pickled objects from .npy, .npz or pickled files. |
| | Natural logarithm, element-wise. |
| | Return the base 10 logarithm of the input array, element-wise. |
| | Return the natural logarithm of one plus the input array, element-wise. |
| | Base-2 logarithm of x. |
| `logaddexp` | Logarithm of the sum of exponentiations of the inputs. |
| `logaddexp2` | Logarithm of the sum of exponentiations of the inputs in base-2. |
| `logical_and(*args)` | Compute the truth value of x1 AND x2 element-wise. |
| `logical_not(*args)` | Compute the truth value of NOT x element-wise. |
| `logical_or(*args)` | Compute the truth value of x1 OR x2 element-wise. |
| `logical_xor(*args)` | Compute the truth value of x1 XOR x2, element-wise. |
| `logspace(start, stop[, num, endpoint, base, …])` | Return numbers spaced evenly on a log scale. |
| `mask_indices(*args, **kwargs)` | Return the indices to access (n, n) arrays, given a masking function. |
| `matmul(a, b, *[, precision])` | Matrix product of two arrays. |
| `max(a[, axis, out, keepdims, initial, where])` | Return the maximum of an array or maximum along an axis. |
| `maximum(x1, x2)` | Element-wise maximum of array elements. |
| `mean(a[, axis, dtype, out, keepdims])` | Compute the arithmetic mean along the specified axis. |
| `median(a[, axis, out, overwrite_input, keepdims])` | Compute the median along the specified axis. |
| `meshgrid(*args, **kwargs)` | Return coordinate matrices from coordinate vectors. |
| `min(a[, axis, out, keepdims, initial, where])` | Return the minimum of an array or minimum along an axis. |
| `minimum(x1, x2)` | Element-wise minimum of array elements. |
| `mod(x1, x2)` | Return element-wise remainder of division. |
| `modf(x[, out])` | Return the fractional and integral parts of an array, element-wise. |
| `moveaxis(a, source, destination)` | Move axes of an array to new positions. |
| | Return a copy of an array sorted along the first axis. |
| `multiply(x1, x2)` | Multiply arguments element-wise. |
| `nanargmax(a[, axis])` | Return the indices of the maximum values in the specified axis, ignoring NaNs. |
| `nanargmin(a[, axis])` | Return the indices of the minimum values in the specified axis, ignoring NaNs. |
| `nancumprod(a[, axis, dtype, out])` | Return the cumulative product of array elements over a given axis, treating Not a Numbers (NaNs) as one. |
| `nancumsum(a[, axis, dtype, out])` | Return the cumulative sum of array elements over a given axis, treating Not a Numbers (NaNs) as zero. |
| `nanmax(a[, axis, out, keepdims])` | Return the maximum of an array or maximum along an axis, ignoring any NaNs. |
| `nanmean(a[, axis, dtype, out, keepdims])` | Compute the arithmetic mean along the specified axis, ignoring NaNs. |
| `nanmedian(a[, axis, out, overwrite_input, …])` | Compute the median along the specified axis, while ignoring NaNs. |
| `nanmin(a[, axis, out, keepdims])` | Return the minimum of an array or minimum along an axis, ignoring any NaNs. |
| `nanpercentile(a, q[, axis, out, …])` | Compute the qth percentile of the data along the specified axis, while ignoring NaN values. |
| `nanprod(a[, axis, dtype, out, keepdims])` | Return the product of array elements over a given axis, treating Not a Numbers (NaNs) as ones. |
| `nanquantile(a, q[, axis, out, …])` | Compute the qth quantile of the data along the specified axis, while ignoring NaN values. |
| `nanstd(a[, axis, dtype, out, ddof, keepdims])` | Compute the standard deviation along the specified axis, while ignoring NaNs. |
| `nansum(a[, axis, dtype, out, keepdims])` | Return the sum of array elements over a given axis, treating Not a Numbers (NaNs) as zero. |
| `nan_to_num(x[, copy, nan, posinf, neginf])` | Replace NaN with zero and infinity with large finite numbers (default behaviour), or with the numbers defined by the user using the nan, posinf and/or neginf keywords. |
| `nanvar(a[, axis, dtype, out, ddof, keepdims])` | Compute the variance along the specified axis, while ignoring NaNs. |
| `ndarray([dtype, buffer, offset, strides, order])` | |
| | Return the number of dimensions of an array. |
| | Numerical negative, element-wise. |
| `nextafter(x1, x2)` | Return the next floating-point value after x1 towards x2, element-wise. |
| | Return the indices of the elements that are non-zero. |
| `not_equal(x1, x2)` | Return (x1 != x2) element-wise. |
| `number` | Abstract base class of all numeric scalar types. |
| `object_` | Any Python object. |
| `ones(shape[, dtype])` | Return a new array of given shape and type, filled with ones. |
| `ones_like(a[, dtype, shape])` | Return an array of ones with the same shape and type as a given array. |
| `outer(a, b[, out])` | Compute the outer product of two vectors. |
| `packbits(a[, axis, bitorder])` | Packs the elements of a binary-valued array into bits in a uint8 array. |
| `pad(array, pad_width[, mode])` | Pad an array. |
| `percentile(a, q[, axis, out, …])` | Compute the q-th percentile of the data along the specified axis. |
| `piecewise(x, condlist, funclist, *args, **kw)` | Evaluate a piecewise-defined function. |
| `polyadd(a1, a2)` | Find the sum of two polynomials. |
| `polyder(p[, m])` | Return the derivative of the specified order of a polynomial. |
| `polymul(a1, a2, *[, trim_leading_zeros])` | Find the product of two polynomials. |
| `polysub(a1, a2)` | Difference (subtraction) of two polynomials. |
| `polyval(p, x)` | Evaluate a polynomial at specific values. |
| | Numerical positive, element-wise. |
| `power(x1, x2)` | First array elements raised to powers from second array, element-wise. |
| `prod(a[, axis, dtype, out, keepdims, …])` | Return the product of array elements over a given axis. |
| `product(a[, axis, dtype, out, keepdims, …])` | Return the product of array elements over a given axis. |
| `promote_types(a, b)` | Returns the type to which a binary operation should cast its arguments. |
| `ptp(a[, axis, out, keepdims])` | Range of values (maximum - minimum) along an axis. |
| `quantile(a, q[, axis, out, overwrite_input, …])` | Compute the q-th quantile of the data along the specified axis. |
| | Convert angles from radians to degrees. |
| | Convert angles from degrees to radians. |
| `ravel(a[, order])` | Return a contiguous flattened array. |
| `ravel_multi_index(multi_index, dims[, mode, …])` | Converts a tuple of index arrays into an array of flat indices, applying boundary modes to the multi-index. |
| `real(val)` | Return the real part of the complex argument. |
| | Return the reciprocal of the argument, element-wise. |
| `remainder(x1, x2)` | Return element-wise remainder of division. |
| `repeat(a, repeats[, axis, total_repeat_length])` | Repeat elements of an array. |
| `reshape(a, newshape[, order])` | Gives a new shape to an array without changing its data. |
| `result_type(*args)` | Returns the type that results from applying the NumPy type promotion rules to the arguments. |
| `right_shift(x1, x2)` | Shift the bits of an integer to the right. |
| | Round elements of the array to the nearest integer. |
| `roll(a, shift[, axis])` | Roll array elements along a given axis. |
| `rollaxis(a, axis[, start])` | Roll the specified axis backwards, until it lies in a given position. |
| `roots(p, *[, strip_zeros])` | Return the roots of a polynomial with coefficients given in p. |
| `rot90(m[, k, axes])` | Rotate an array by 90 degrees in the plane specified by axes. |
| `round(a[, decimals, out])` | Round an array to the given number of decimals. |
| `row_stack(tup)` | Stack arrays in sequence vertically (row wise). |
| `save(file, arr[, allow_pickle, fix_imports])` | Save an array to a binary file in NumPy .npy format. |
| `savez(file, *args, **kwds)` | Save several arrays into a single file in uncompressed .npz format. |
| `searchsorted(a, v[, side, sorter])` | Find indices where elements should be inserted to maintain order. |
| `select(condlist, choicelist[, default])` | Return an array drawn from elements in choicelist, depending on conditions. |
| `set_printoptions([precision, threshold, …])` | Set printing options. |
| `setdiff1d(ar1, ar2[, assume_unique])` | Find the set difference of two arrays. |
| | Return the shape of an array. |
| | Returns an element-wise indication of the sign of a number. |
| | Returns element-wise True where signbit is set (less than zero). |
| `signedinteger` | Abstract base class of all signed integer scalar types. |
| | Trigonometric sine, element-wise. |
| | Return the sinc function. |
| `single` | Alias of `jax._src.numpy.lax_numpy.float32`. |
| | Hyperbolic sine, element-wise. |
| `size(a[, axis])` | Return the number of elements along a given axis. |
| `sometrue(a[, axis, out, keepdims])` | Test whether any array element along a given axis evaluates to True. |
| `sort(a[, axis, kind, order])` | Return a sorted copy of an array. |
| | Sort a complex array using the real part first, then the imaginary part. |
| `split(ary, indices_or_sections[, axis])` | Split an array into multiple sub-arrays as views into ary. |
| | Return the non-negative square-root of an array, element-wise. |
| | Return the element-wise square of the input. |
| `squeeze(a[, axis])` | Remove single-dimensional entries from the shape of an array. |
| `stack(arrays[, axis, out])` | Join a sequence of arrays along a new axis. |
| `std(a[, axis, dtype, out, ddof, keepdims])` | Compute the standard deviation along the specified axis. |
| `subtract(x1, x2)` | Subtract arguments, element-wise. |
| `sum(a[, axis, dtype, out, keepdims, …])` | Sum of array elements over a given axis. |
| `swapaxes(a, axis1, axis2)` | Interchange two axes of an array. |
| `take(a, indices[, axis, out, mode])` | Take elements from an array along an axis. |
| `take_along_axis(arr, indices, axis)` | Take values from the input array by matching 1d index and data slices. |
| | Compute tangent element-wise. |
| | Compute hyperbolic tangent element-wise. |
| `tensordot(a, b[, axes, precision])` | Compute tensor dot product along specified axes. |
| `tile(A, reps)` | Construct an array by repeating A the number of times given by reps. |
| `trace(a[, offset, axis1, axis2, dtype, out])` | Return the sum along diagonals of the array. |
| `transpose(a[, axes])` | Reverse or permute the axes of an array; returns the modified array. |
| `trapz(y[, x, dx, axis])` | Integrate along the given axis using the composite trapezoidal rule. |
| `tri(N[, M, k, dtype])` | An array with ones at and below the given diagonal and zeros elsewhere. |
| `tril(m[, k])` | Lower triangle of an array. |
| `tril_indices(*args, **kwargs)` | Return the indices for the lower-triangle of an (n, m) array. |
| `tril_indices_from(arr[, k])` | Return the indices for the lower-triangle of arr. |
| `trim_zeros(filt[, trim])` | Trim the leading and/or trailing zeros from a 1-D array or sequence. |
| `triu(m[, k])` | Upper triangle of an array. |
| `triu_indices(*args, **kwargs)` | Return the indices for the upper-triangle of an (n, m) array. |
| `triu_indices_from(arr[, k])` | Return the indices for the upper-triangle of arr. |
| `true_divide(x1, x2)` | Returns a true division of the inputs, element-wise. |
| | Return the truncated value of the input, element-wise. |
| `uint16` | |
| `uint32` | |
| `uint64` | |
| `uint8` | |
| `unique(ar[, return_index, return_inverse, …])` | Find the unique elements of an array. |
| `unpackbits(a[, axis, count, bitorder])` | Unpacks elements of a uint8 array into a binary-valued output array. |
| `unravel_index(indices, shape)` | Converts a flat index or array of flat indices into a tuple of coordinate arrays. |
| `unsignedinteger` | Abstract base class of all unsigned integer scalar types. |
| `unwrap(p[, discont, axis])` | Unwrap by changing deltas between values to their `2*pi` complement. |
| `vander(x[, N, increasing])` | Generate a Vandermonde matrix. |
| `var(a[, axis, dtype, out, ddof, keepdims])` | Compute the variance along the specified axis. |
| `vdot(a, b, *[, precision])` | Return the dot product of two vectors. |
| `vectorize(pyfunc, *[, excluded, signature])` | Define a vectorized function with broadcasting. |
| `vsplit(ary, indices_or_sections)` | Split an array into multiple sub-arrays vertically (row-wise). |
| `vstack(tup)` | Stack arrays in sequence vertically (row wise). |
| `where(condition[, x, y])` | Return elements chosen from x or y depending on condition. |
| `zeros(shape[, dtype])` | Return a new array of given shape and type, filled with zeros. |
| `zeros_like(a[, dtype, shape])` | Return an array of zeros with the same shape and type as a given array. |
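Here is a small sketch that combines a few of the listed functions, using arbitrary example values. It shows `einsum` expressing the same contraction as `matmul`, the three-argument form of `where`, and `cumsum` along an axis. The three-argument `where` always returns an array of the same shape as its inputs, so, unlike `nonzero`, it does not run into the data-dependent shape problem.

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
b = jnp.ones((3, 2))

# einsum spells out the matrix product a @ b as an index contraction.
c_einsum = jnp.einsum("ij,jk->ik", a, b)
c_matmul = jnp.matmul(a, b)

# Three-argument where: negate the odd entries, keep the even ones.
v = jnp.arange(5)
w = jnp.where(v % 2 == 1, -v, v)

# Cumulative sum along each row.
s = jnp.cumsum(a, axis=1)

print(jnp.allclose(c_einsum, c_matmul))  # True
print(w)  # values: 0, -1, 2, -3, 4
print(s)  # rows: [0, 1, 3] and [3, 7, 12]
```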

## jax.numpy.fft

| Name | Description |
|---|---|
| `fft(a[, n, axis, norm])` | Compute the one-dimensional discrete Fourier Transform. |
| `fft2(a[, s, axes, norm])` | Compute the 2-dimensional discrete Fourier Transform. |
| `fftfreq(n[, d])` | Return the Discrete Fourier Transform sample frequencies. |
| `fftn(a[, s, axes, norm])` | Compute the N-dimensional discrete Fourier Transform. |
| `fftshift(x[, axes])` | Shift the zero-frequency component to the center of the spectrum. |
| `hfft(a[, n, axis, norm])` | Compute the FFT of a signal that has Hermitian symmetry, i.e., a real spectrum. |
| `ifft(a[, n, axis, norm])` | Compute the one-dimensional inverse discrete Fourier Transform. |
| `ifft2(a[, s, axes, norm])` | Compute the 2-dimensional inverse discrete Fourier Transform. |
| `ifftn(a[, s, axes, norm])` | Compute the N-dimensional inverse discrete Fourier Transform. |
| `ifftshift(x[, axes])` | The inverse of fftshift. |
| `ihfft(a[, n, axis, norm])` | Compute the inverse FFT of a signal that has Hermitian symmetry. |
| `irfft(a[, n, axis, norm])` | Compute the inverse of the n-point DFT for real input. |
| `irfft2(a[, s, axes, norm])` | Compute the 2-dimensional inverse FFT of a real array. |
| `irfftn(a[, s, axes, norm])` | Compute the inverse of the N-dimensional FFT of real input. |
| `rfft(a[, n, axis, norm])` | Compute the one-dimensional discrete Fourier Transform for real input. |
| `rfft2(a[, s, axes, norm])` | Compute the 2-dimensional FFT of a real array. |
| `rfftfreq(n[, d])` | Return the Discrete Fourier Transform sample frequencies (for use with rfft, irfft). |
| `rfftn(a[, s, axes, norm])` | Compute the N-dimensional discrete Fourier Transform for real input. |
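This minimal sketch exercises `rfft`, `rfftfreq` and `irfft` together. A real cosine that completes two cycles over eight samples should produce a spectral peak at 0.25 cycles per sample, and `irfft` should recover the original signal. The signal itself is an arbitrary example.

```python
import jax.numpy as jnp

n = 8
t = jnp.arange(n)
# Real input: a cosine completing 2 full cycles over 8 samples.
signal = jnp.cos(2 * jnp.pi * 2 * t / n)

spectrum = jnp.fft.rfft(signal)           # n // 2 + 1 = 5 complex bins
freqs = jnp.fft.rfftfreq(n)               # 0, 0.125, 0.25, 0.375, 0.5 (cycles/sample)
recovered = jnp.fft.irfft(spectrum, n=n)  # back to 8 real samples

peak = int(jnp.argmax(jnp.abs(spectrum)))
print(spectrum.shape, float(freqs[peak]))  # (5,) 0.25
print(jnp.allclose(recovered, signal, atol=1e-5))
```

Passing `n` explicitly to `irfft` matters for odd-length signals, since the real output length cannot be inferred unambiguously from `n // 2 + 1` bins.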

## jax.numpy.linalg

| Name | Description |
|---|---|
| | Cholesky decomposition. |
| `cond(x[, p])` | Compute the condition number of a matrix. |
| `det` | Compute the determinant of an array. |
| | Compute the eigenvalues and right eigenvectors of a square array. |
| `eigh(a[, UPLO, symmetrize_input])` | Return the eigenvalues and eigenvectors of a complex Hermitian (conjugate symmetric) or a real symmetric matrix. |
| | Compute the eigenvalues of a general matrix. |
| `eigvalsh(a[, UPLO])` | Compute the eigenvalues of a complex Hermitian or real symmetric matrix. |
| | Compute the (multiplicative) inverse of a matrix. |
| `lstsq(a, b[, rcond, numpy_resid])` | Return the least-squares solution to a linear matrix equation. |
| `matrix_power(a, n)` | Raise a square matrix to the (integer) power n. |
| `matrix_rank(M[, tol])` | Return the matrix rank of an array using the SVD method. |
| `multi_dot(arrays, *[, precision])` | Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order. |
| `norm(x[, ord, axis, keepdims])` | Matrix or vector norm. |
| `pinv` | Compute the (Moore-Penrose) pseudo-inverse of a matrix. |
| `qr(a[, mode])` | Compute the QR factorization of a matrix. |
| `slogdet` | Compute the sign and (natural) logarithm of the determinant of an array. |
| `solve(a, b)` | Solve a linear matrix equation, or system of linear scalar equations. |
| `svd(a[, full_matrices, compute_uv])` | Singular Value Decomposition. |
| `tensorinv(a[, ind])` | Compute the ‘inverse’ of an N-dimensional array. |
| `tensorsolve(a, b[, axes])` | Solve the tensor equation a x = b for x. |
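The following minimal sketch uses a small symmetric 2×2 system with arbitrary values. It solves the system with `solve`, checks the eigendecomposition returned by `eigh`, and reconstructs the determinant from `slogdet`. Calling `solve` directly is the usual numerical recommendation over forming the inverse explicitly.

```python
import jax.numpy as jnp

# A small symmetric system: 3x + y = 9, x + 2y = 8 (solution x = 2, y = 3).
a = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([9.0, 8.0])

# Solve a @ x = b without forming inv(a).
x = jnp.linalg.solve(a, b)

# a is real symmetric, so eigh applies: a @ v[:, i] == w[i] * v[:, i].
w, v = jnp.linalg.eigh(a)

# slogdet returns (sign, log|det|); det(a) = 3*2 - 1*1 = 5.
sign, logdet = jnp.linalg.slogdet(a)

print(x)                              # approximately [2, 3]
print(jnp.allclose(a @ x, b))         # True
print(float(sign * jnp.exp(logdet)))  # approximately 5.0
```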