jax.numpy package

Implements the NumPy API, using the primitives in jax.lax.

While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.

  • Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, JAX is often able to provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides the alternative pure indexed update function jax.ops.index_update().

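  • NumPy is very aggressive at promoting values to the float64 type. JAX is sometimes less aggressive about type promotion; for example, by default JAX uses 32-bit types such as float32 unless 64-bit mode is enabled through the jax_enable_x64 configuration option.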
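A minimal sketch of both points, assuming a recent JAX release. In newer versions jax.ops.index_update() has been removed, and the same pure indexed update is written with the .at property of an array. The dtype check assumes the default configuration, with jax_enable_x64 turned off.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# In-place assignment is rejected because JAX arrays are immutable;
# it raises a TypeError.
try:
    x[2] = 1.0
except TypeError:
    pass

# The functional alternative returns a new array and leaves `x` untouched.
y = x.at[2].set(1.0)

# With the default configuration (64-bit mode off), floating-point
# literals become float32 rather than NumPy's float64.
default_dtype = jnp.array([1.0, 2.0]).dtype
```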

A small number of NumPy operations that have data-dependent output shapes are incompatible with jax.jit() compilation. The XLA compiler requires that shapes of arrays be known at compile time. While it would be possible to provide a JAX implementation of an API such as numpy.nonzero(), we would be unable to JIT-compile it because the shape of its output depends on the contents of the input data.
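As an illustration (a sketch, not the only workaround): jnp.nonzero works eagerly because the data is concrete, while inside a jax.jit-compiled function the usual pattern is to keep every shape fixed and mask values with the three-argument form of jnp.where. The function name sum_positive is hypothetical.

```python
import jax
import jax.numpy as jnp

x = jnp.array([3.0, -1.0, 0.0, 5.0])

# Eager mode: the output length depends on the data, which is fine here.
(idx,) = jnp.nonzero(x > 0)

@jax.jit
def sum_positive(v):
    # Under jit, avoid data-dependent shapes: zero out unwanted entries
    # instead of selecting them, so the intermediate keeps the shape of `v`.
    return jnp.sum(jnp.where(v > 0, v, 0.0))

total = sum_positive(x)
```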

Not every function in NumPy is implemented; contributions are welcome!

abs(x)

Calculate the absolute value element-wise.

absolute(x)

Calculate the absolute value element-wise.

add(x1, x2)

Add arguments element-wise.

all(a[, axis, out, keepdims])

Test whether all array elements along a given axis evaluate to True.

allclose(a, b[, rtol, atol, equal_nan])

Returns True if two arrays are element-wise equal within a tolerance.

alltrue(a[, axis, out, keepdims])

Test whether all array elements along a given axis evaluate to True.

amax(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis.

amin(a[, axis, out, keepdims, initial, where])

Return the minimum of an array or minimum along an axis.

angle(z)

Return the angle of the complex argument.

any(a[, axis, out, keepdims])

Test whether any array element along a given axis evaluates to True.

append(arr, values[, axis])

Append values to the end of an array.

apply_along_axis(func1d, axis, arr, *args, …)

Apply a function to 1-D slices along the given axis.

apply_over_axes(func, a, axes)

Apply a function repeatedly over multiple axes.

arange(start[, stop, step, dtype])

Return evenly spaced values within a given interval.

arccos(x)

Trigonometric inverse cosine, element-wise.

arccosh(x)

Inverse hyperbolic cosine, element-wise.

arcsin(x)

Inverse sine, element-wise.

arcsinh(x)

Inverse hyperbolic sine, element-wise.

arctan(x)

Trigonometric inverse tangent, element-wise.

arctan2(x1, x2)

Element-wise arc tangent of x1/x2 choosing the quadrant correctly.

arctanh(x)

Inverse hyperbolic tangent, element-wise.

argmax(a[, axis, out])

Returns the indices of the maximum values along an axis.

argmin(a[, axis, out])

Returns the indices of the minimum values along an axis.

argsort(a[, axis, kind, order])

Returns the indices that would sort an array.

argwhere(a)

Find the indices of array elements that are non-zero, grouped by element.

around(a[, decimals, out])

Round an array to the given number of decimals.

array(object[, dtype, copy, order, ndmin])

Create an array.

array_equal(a1, a2[, equal_nan])

True if two arrays have the same shape and elements, False otherwise.

array_equiv(a1, a2)

Returns True if input arrays are shape consistent and all elements equal.

array_repr(arr[, max_line_width, precision, …])

Return the string representation of an array.

array_split(ary, indices_or_sections[, axis])

Split an array into multiple sub-arrays.

array_str(a[, max_line_width, precision, …])

Return a string representation of the data in an array.

asarray(a[, dtype, order])

Convert the input to an array.

atleast_1d(*arys)

Convert inputs to arrays with at least one dimension.

atleast_2d(*arys)

View inputs as arrays with at least two dimensions.

atleast_3d(*arys)

View inputs as arrays with at least three dimensions.

average(a[, axis, weights, returned])

Compute the weighted average along the specified axis.

bartlett(*args, **kwargs)

Return the Bartlett window.

bincount(x[, weights, minlength, length])

Count number of occurrences of each value in array of non-negative ints.
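Because the output length of bincount normally depends on the largest value in x, the JAX-specific length argument is what makes it usable under jax.jit(). A minimal sketch; count_labels is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

@jax.jit
def count_labels(labels):
    # `length` is a static output size, so the result shape (5,) is known
    # at trace time even though `labels` is an abstract tracer here.
    return jnp.bincount(labels, length=5)

counts = count_labels(jnp.array([0, 1, 1, 3]))
# counts holds [1, 2, 0, 1, 0]
```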

bitwise_and(x1, x2)

Compute the bit-wise AND of two arrays element-wise.

bitwise_not(x)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

bitwise_or(x1, x2)

Compute the bit-wise OR of two arrays element-wise.

bitwise_xor(x1, x2)

Compute the bit-wise XOR of two arrays element-wise.

blackman(*args, **kwargs)

Return the Blackman window.

block(arrays)

Assemble an nd-array from nested lists of blocks.

bool_

broadcast_arrays(*args)

Like NumPy’s broadcast_arrays, but doesn’t return views.

broadcast_to(arr, shape)

Broadcast an array to a new shape.

can_cast(from_, to[, casting])

Returns True if cast between data types can occur according to the casting rule.

cbrt(x)

Return the cube-root of an array, element-wise.

cdouble

alias of jax._src.numpy.lax_numpy.complex128

ceil(x)

Return the ceiling of the input, element-wise.

character

Abstract base class of all character string scalar types.

choose(a, choices[, out, mode])

Construct an array from an index array and a set of arrays to choose from.

clip(a[, a_min, a_max, out])

Clip (limit) the values in an array.

column_stack(tup)

Stack 1-D arrays as columns into a 2-D array.

complex_

alias of jax._src.numpy.lax_numpy.complex128

complex128

complex64

complexfloating

Abstract base class of all complex number scalar types that are made up of floating-point numbers.

ComplexWarning

The warning raised when casting a complex dtype to a real dtype.

compress(condition, a[, axis, out])

Return selected slices of an array along given axis.

concatenate(arrays[, axis])

Join a sequence of arrays along an existing axis.

conj(x)

Return the complex conjugate, element-wise.

conjugate(x)

Return the complex conjugate, element-wise.

convolve(a, v[, mode, precision])

Returns the discrete, linear convolution of two one-dimensional sequences.
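A short example of the default "full" mode and the "same" mode, which follow numpy.convolve:

```python
import jax.numpy as jnp

signal = jnp.array([1.0, 2.0, 3.0])
kernel = jnp.array([0.0, 1.0, 0.5])

# mode="full" (the default) returns len(signal) + len(kernel) - 1 samples:
# out[k] = sum_i signal[i] * kernel[k - i]
full = jnp.convolve(signal, kernel)              # [0. , 1. , 2.5, 4. , 1.5]

# mode="same" returns max(len(signal), len(kernel)) samples, centered.
same = jnp.convolve(signal, kernel, mode="same")  # [1. , 2.5, 4. ]
```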

copysign(x1, x2)

Change the sign of x1 to that of x2, element-wise.

corrcoef(x[, y, rowvar])

Return Pearson product-moment correlation coefficients.

correlate(a, v[, mode, precision])

Cross-correlation of two 1-dimensional sequences.

cos(x)

Cosine element-wise.

cosh(x)

Hyperbolic cosine, element-wise.

count_nonzero(a[, axis, keepdims])

Counts the number of non-zero values in the array a.

cov(m[, y, rowvar, bias, ddof, fweights, …])

Estimate a covariance matrix, given data and weights.

cross(a, b[, axisa, axisb, axisc, axis])

Return the cross product of two (arrays of) vectors.

csingle

alias of jax._src.numpy.lax_numpy.complex64

cumprod(a[, axis, dtype, out])

Return the cumulative product of elements along a given axis.

cumproduct(a[, axis, dtype, out])

Return the cumulative product of elements along a given axis.

cumsum(a[, axis, dtype, out])

Return the cumulative sum of the elements along a given axis.

deg2rad(x)

Convert angles from degrees to radians.

degrees(x)

Convert angles from radians to degrees.

diag(v[, k])

Extract a diagonal or construct a diagonal array.

diagflat(v[, k])

Create a two-dimensional array with the flattened input as a diagonal.

diag_indices(n[, ndim])

Return the indices to access the main diagonal of an array.

diag_indices_from(arr)

Return the indices to access the main diagonal of an n-dimensional array.

diagonal(a[, offset, axis1, axis2])

Return specified diagonals.

diff(a[, n, axis, prepend, append])

Calculate the n-th discrete difference along the given axis.

digitize(x, bins[, right])

Return the indices of the bins to which each value in input array belongs.

divide(x1, x2)

Returns a true division of the inputs, element-wise.

divmod(x1, x2)

Return element-wise quotient and remainder simultaneously.

dot(a, b, *[, precision])

Dot product of two arrays.

double

alias of jax._src.numpy.lax_numpy.float64

dsplit(ary, indices_or_sections)

Split array into multiple sub-arrays along the 3rd axis (depth).

dstack(tup)

Stack arrays in sequence depth wise (along third axis).

dtype(obj[, align, copy])

Create a data type object.

ediff1d(ary[, to_end, to_begin])

The differences between consecutive elements of an array.

einsum(*operands[, out, optimize, precision])

Evaluates the Einstein summation convention on the operands.
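A sketch of two common subscript strings: one expresses a matrix product, the other a trace.

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)

# "ij,jk->ik" contracts the shared index j: an ordinary matrix product.
prod = jnp.einsum("ij,jk->ik", a, b)

# "ii->" sums the diagonal of a square matrix, i.e. its trace.
tr = jnp.einsum("ii->", jnp.eye(3))
```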

einsum_path(subscripts, *operands[, optimize])

Evaluates the lowest cost contraction order for an einsum expression by considering the creation of intermediate arrays.

empty(shape[, dtype])

Return a new array of given shape and type, filled with zeros.

empty_like(a[, dtype, shape])

Return an array of zeros with the same shape and type as a given array.

equal(x1, x2)

Return (x1 == x2) element-wise.

exp(x)

Calculate the exponential of all elements in the input array.

exp2(x)

Calculate 2**p for all p in the input array.

expand_dims(a, axis)

Expand the shape of an array.

expm1(x)

Calculate exp(x) - 1 for all elements in the array.

extract(condition, arr)

Return the elements of an array that satisfy some condition.

eye(N[, M, k, dtype])

Return a 2-D array with ones on the diagonal and zeros elsewhere.

fabs(x)

Compute the absolute values element-wise.

finfo(dtype)

Machine limits for floating point types.

fix(x[, out])

Round to nearest integer towards zero.

flatnonzero(a)

Return indices that are non-zero in the flattened version of a.

flexible

Abstract base class of all scalar types without predefined length.

flip(m[, axis])

Reverse the order of elements in an array along the given axis.

fliplr(m)

Flip array in the left/right direction.

flipud(m)

Flip array in the up/down direction.

float_

alias of jax._src.numpy.lax_numpy.float64

float16

float32

float64

floating

Abstract base class of all floating-point scalar types.

float_power(x1, x2)

First array elements raised to powers from second array, element-wise.

floor(x)

Return the floor of the input, element-wise.

floor_divide(x1, x2)

Return the largest integer smaller or equal to the division of the inputs.

fmax(x1, x2)

Element-wise maximum of array elements.

fmin(x1, x2)

Element-wise minimum of array elements.

fmod(x1, x2)

Return the element-wise remainder of division.

frexp(x)

Decompose the elements of x into mantissa and base-2 exponent.

full(shape, fill_value[, dtype])

Return a new array of given shape and type, filled with fill_value.

full_like(a, fill_value[, dtype, shape])

Return a full array with the same shape and type as a given array.

gcd(x1, x2)

Returns the greatest common divisor of |x1| and |x2|.

geomspace(start, stop[, num, endpoint, …])

Return numbers spaced evenly on a log scale (a geometric progression).

gradient(f, *varargs[, axis, edge_order])

Return the gradient of an N-dimensional array.

greater(x1, x2)

Return the truth value of (x1 > x2) element-wise.

greater_equal(x1, x2)

Return the truth value of (x1 >= x2) element-wise.

hamming(*args, **kwargs)

Return the Hamming window.

hanning(*args, **kwargs)

Return the Hanning window.

heaviside(x1, x2)

Compute the Heaviside step function.

histogram(a[, bins, range, weights, density])

Compute the histogram of a set of data.

histogram_bin_edges(a[, bins, range, weights])

Function to calculate only the edges of the bins used by the histogram function.

histogram2d(x, y[, bins, range, weights, …])

Compute the bi-dimensional histogram of two data samples.

histogramdd(sample[, bins, range, weights, …])

Compute the multidimensional histogram of some data.

hsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays horizontally (column-wise).

hstack(tup)

Stack arrays in sequence horizontally (column wise).

hypot(x1, x2)

Given the “legs” of a right triangle, return its hypotenuse.

i0(x)

Modified Bessel function of the first kind, order 0.

identity(n[, dtype])

Return the identity array.

iinfo(type)

Machine limits for integer types.

imag(val)

Return the imaginary part of the complex argument.

in1d(ar1, ar2[, assume_unique, invert])

Test whether each element of a 1-D array is also present in a second array.

indices(dimensions[, dtype, sparse])

Return an array representing the indices of a grid.

inexact

Abstract base class of all numeric scalar types with a (potentially) inexact representation of the values in its range, such as floating-point numbers.

inner(a, b, *[, precision])

Inner product of two arrays.

int_

alias of jax._src.numpy.lax_numpy.int64

int16

int32

int64

int8

integer

Abstract base class of all integer scalar types.

interp(x, xp, fp[, left, right, period])

One-dimensional linear interpolation.

intersect1d(ar1, ar2[, assume_unique, …])

Find the intersection of two arrays.

invert(x)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

isclose(a, b[, rtol, atol, equal_nan])

Returns a boolean array where two arrays are element-wise equal within a tolerance.

iscomplex(x)

Returns a bool array, where True if input element is complex.

iscomplexobj(x)

Check for a complex type or an array of complex numbers.

isfinite(x)

Test element-wise for finiteness (not infinity and not Not a Number).

isin(element, test_elements[, …])

Calculate whether each value of element is in test_elements, broadcasting over element only.

isinf(x)

Test element-wise for positive or negative infinity.

isnan(x)

Test element-wise for NaN and return result as a boolean array.

isneginf(x[, out])

Test element-wise for negative infinity, return result as bool array.

isposinf(x[, out])

Test element-wise for positive infinity, return result as bool array.

isreal(x)

Returns a bool array, where True if input element is real.

isrealobj(x)

Return True if x is not a complex type or an array of complex numbers.

isscalar(element)

Returns True if the type of element is a scalar type.

issubdtype(arg1, arg2)

Returns True if first argument is a typecode lower/equal in type hierarchy.

issubsctype(arg1, arg2)

Determine if the first argument is a subclass of the second argument.

iterable(y)

Check whether or not an object can be iterated over.

ix_(*args)

Construct an open mesh from multiple sequences.

kaiser(*args, **kwargs)

Return the Kaiser window.

kron(a, b)

Kronecker product of two arrays.

lcm(x1, x2)

Returns the lowest common multiple of |x1| and |x2|.

ldexp(x1, x2)

Returns x1 * 2**x2, element-wise.

left_shift(x1, x2)

Shift the bits of an integer to the left.

less(x1, x2)

Return the truth value of (x1 < x2) element-wise.

less_equal(x1, x2)

Return the truth value of (x1 <= x2) element-wise.

lexsort(keys[, axis])

Perform an indirect stable sort using a sequence of keys.

linspace(start, stop[, num, endpoint, …])

Return evenly spaced numbers over a specified interval.

load(file[, mmap_mode, allow_pickle, …])

Load arrays or pickled objects from .npy, .npz or pickled files.

log(x)

Natural logarithm, element-wise.

log10(x)

Return the base 10 logarithm of the input array, element-wise.

log1p(x)

Return the natural logarithm of one plus the input array, element-wise.

log2(x)

Base-2 logarithm of x.

logaddexp

Logarithm of the sum of exponentiations of the inputs.

logaddexp2

Logarithm of the sum of exponentiations of the inputs in base-2.

logical_and(*args)

Compute the truth value of x1 AND x2 element-wise.

logical_not(*args)

Compute the truth value of NOT x element-wise.

logical_or(*args)

Compute the truth value of x1 OR x2 element-wise.

logical_xor(*args)

Compute the truth value of x1 XOR x2, element-wise.

logspace(start, stop[, num, endpoint, base, …])

Return numbers spaced evenly on a log scale.

mask_indices(*args, **kwargs)

Return the indices to access (n, n) arrays, given a masking function.

matmul(a, b, *[, precision])

Matrix product of two arrays.
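The precision keyword in the signature above is specific to JAX. As a hedged example: on some accelerators (notably TPUs) the default float32 matrix multiply may use faster, lower-precision passes, and jax.lax.Precision.HIGHEST requests the most accurate algorithm the backend offers. On CPU both calls below should agree.

```python
import jax
import jax.numpy as jnp

a = jnp.ones((4, 3))
b = jnp.ones((3, 2))

# Backend-default precision.
fast = jnp.matmul(a, b)

# Explicitly request the most accurate available algorithm.
exact = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)
```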

max(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis.

maximum(x1, x2)

Element-wise maximum of array elements.

mean(a[, axis, dtype, out, keepdims])

Compute the arithmetic mean along the specified axis.

median(a[, axis, out, overwrite_input, keepdims])

Compute the median along the specified axis.

meshgrid(*args, **kwargs)

Return coordinate matrices from coordinate vectors.

min(a[, axis, out, keepdims, initial, where])

Return the minimum of an array or minimum along an axis.

minimum(x1, x2)

Element-wise minimum of array elements.

mod(x1, x2)

Return element-wise remainder of division.

modf(x[, out])

Return the fractional and integral parts of an array, element-wise.

moveaxis(a, source, destination)

Move axes of an array to new positions.

msort(a)

Return a copy of an array sorted along the first axis.

multiply(x1, x2)

Multiply arguments element-wise.

nanargmax(a[, axis])

Return the indices of the maximum values in the specified axis ignoring NaNs.

nanargmin(a[, axis])

Return the indices of the minimum values in the specified axis ignoring NaNs.

nancumprod(a[, axis, dtype, out])

Return the cumulative product of array elements over a given axis treating Not a Numbers (NaNs) as one.

nancumsum(a[, axis, dtype, out])

Return the cumulative sum of array elements over a given axis treating Not a Numbers (NaNs) as zero.

nanmax(a[, axis, out, keepdims])

Return the maximum of an array or maximum along an axis, ignoring any NaNs.

nanmean(a[, axis, dtype, out, keepdims])

Compute the arithmetic mean along the specified axis, ignoring NaNs.

nanmedian(a[, axis, out, overwrite_input, …])

Compute the median along the specified axis, while ignoring NaNs.

nanmin(a[, axis, out, keepdims])

Return minimum of an array or minimum along an axis, ignoring any NaNs.

nanpercentile(a, q[, axis, out, …])

Compute the q-th percentile of the data along the specified axis, while ignoring NaN values.

nanprod(a[, axis, dtype, out, keepdims])

Return the product of array elements over a given axis treating Not a Numbers (NaNs) as ones.

nanquantile(a, q[, axis, out, …])

Compute the q-th quantile of the data along the specified axis, while ignoring NaN values.

nanstd(a[, axis, dtype, out, ddof, keepdims])

Compute the standard deviation along the specified axis, while ignoring NaNs.

nansum(a[, axis, dtype, out, keepdims])

Return the sum of array elements over a given axis treating Not a Numbers (NaNs) as zero.

nan_to_num(x[, copy, nan, posinf, neginf])

Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the nan, posinf and/or neginf keywords.

nanvar(a[, axis, dtype, out, ddof, keepdims])

Compute the variance along the specified axis, while ignoring NaNs.

ndarray([dtype, buffer, offset, strides, order])

ndim(a)

Return the number of dimensions of an array.

negative(x)

Numerical negative, element-wise.

nextafter(x1, x2)

Return the next floating-point value after x1 towards x2, element-wise.

nonzero(a)

Return the indices of the elements that are non-zero.

not_equal(x1, x2)

Return (x1 != x2) element-wise.

number

Abstract base class of all numeric scalar types.

object_

Any Python object.

ones(shape[, dtype])

Return a new array of given shape and type, filled with ones.

ones_like(a[, dtype, shape])

Return an array of ones with the same shape and type as a given array.

outer(a, b[, out])

Compute the outer product of two vectors.

packbits(a[, axis, bitorder])

Packs the elements of a binary-valued array into bits in a uint8 array.

pad(array, pad_width[, mode])

Pad an array.

percentile(a, q[, axis, out, …])

Compute the q-th percentile of the data along the specified axis.

piecewise(x, condlist, funclist, *args, **kw)

Evaluate a piecewise-defined function.

polyadd(a1, a2)

Find the sum of two polynomials.

polyder(p[, m])

Return the derivative of the specified order of a polynomial.

polymul(a1, a2, *[, trim_leading_zeros])

Find the product of two polynomials.

polysub(a1, a2)

Difference (subtraction) of two polynomials.

polyval(p, x)

Evaluate a polynomial at specific values.
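Coefficients follow NumPy's convention, highest power first, as in this sketch:

```python
import jax.numpy as jnp

# p represents 1*x**2 + 2*x + 3.
p = jnp.array([1.0, 2.0, 3.0])

at_two = jnp.polyval(p, 2.0)                     # 1*4 + 2*2 + 3 = 11
at_many = jnp.polyval(p, jnp.array([0.0, 1.0]))  # [3., 6.]
```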

positive(x)

Numerical positive, element-wise.

power(x1, x2)

First array elements raised to powers from second array, element-wise.

prod(a[, axis, dtype, out, keepdims, …])

Return the product of array elements over a given axis.

product(a[, axis, dtype, out, keepdims, …])

Return the product of array elements over a given axis.

promote_types(a, b)

Returns the type to which a binary operation should cast its arguments.

ptp(a[, axis, out, keepdims])

Range of values (maximum - minimum) along an axis.

quantile(a, q[, axis, out, overwrite_input, …])

Compute the q-th quantile of the data along the specified axis.

rad2deg(x)

Convert angles from radians to degrees.

radians(x)

Convert angles from degrees to radians.

ravel(a[, order])

Return a contiguous flattened array.

ravel_multi_index(multi_index, dims[, mode, …])

Converts a tuple of index arrays into an array of flat indices, applying boundary modes to the multi-index.

real(val)

Return the real part of the complex argument.

reciprocal(x)

Return the reciprocal of the argument, element-wise.

remainder(x1, x2)

Return element-wise remainder of division.

repeat(a, repeats[, axis, total_repeat_length])

Repeat elements of an array.
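Like the length argument of bincount, total_repeat_length lets repeat run under jax.jit() when repeats is a traced array, since the output size can no longer be read from the data. A minimal sketch, assuming the counts sum exactly to the requested length; expand is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

@jax.jit
def expand(values, counts):
    # `counts` is traced, so the output length cannot be inferred from it;
    # total_repeat_length supplies that static length (here sum(counts) == 6).
    return jnp.repeat(values, counts, total_repeat_length=6)

out = expand(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))
# out holds [1, 2, 2, 3, 3, 3]
```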

reshape(a, newshape[, order])

Gives a new shape to an array without changing its data.

result_type(*args)

Returns the type that results from applying the NumPy type promotion rules to the arguments.

right_shift(x1, x2)

Shift the bits of an integer to the right.

rint(x)

Round elements of the array to the nearest integer.

roll(a, shift[, axis])

Roll array elements along a given axis.

rollaxis(a, axis[, start])

Roll the specified axis backwards, until it lies in a given position.

roots(p, *[, strip_zeros])

Return the roots of a polynomial with coefficients given in p.

rot90(m[, k, axes])

Rotate an array by 90 degrees in the plane specified by axes.

round(a[, decimals, out])

Round an array to the given number of decimals.

row_stack(tup)

Stack arrays in sequence vertically (row wise).

save(file, arr[, allow_pickle, fix_imports])

Save an array to a binary file in NumPy .npy format.

savez(file, *args, **kwds)

Save several arrays into a single file in uncompressed .npz format.

searchsorted(a, v[, side, sorter])

Find indices where elements should be inserted to maintain order.
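A small example contrasting the two values of side:

```python
import jax.numpy as jnp

sorted_vals = jnp.array([1, 3, 5, 7])
queries = jnp.array([0, 4, 7])

# side="left" (default): first index i with sorted_vals[i] >= query.
left = jnp.searchsorted(sorted_vals, queries)                 # [0, 2, 3]

# side="right": first index i with sorted_vals[i] > query.
right = jnp.searchsorted(sorted_vals, queries, side="right")  # [0, 2, 4]
```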

select(condlist, choicelist[, default])

Return an array drawn from elements in choicelist, depending on conditions.

set_printoptions([precision, threshold, …])

Set printing options.

setdiff1d(ar1, ar2[, assume_unique])

Find the set difference of two arrays.

shape(a)

Return the shape of an array.

sign(x)

Returns an element-wise indication of the sign of a number.

signbit(x)

Returns element-wise True where signbit is set (less than zero).

signedinteger

Abstract base class of all signed integer scalar types.

sin(x)

Trigonometric sine, element-wise.

sinc(x)

Return the sinc function.

single

alias of jax._src.numpy.lax_numpy.float32

sinh(x)

Hyperbolic sine, element-wise.

size(a[, axis])

Return the number of elements along a given axis.

sometrue(a[, axis, out, keepdims])

Test whether any array element along a given axis evaluates to True.

sort(a[, axis, kind, order])

Return a sorted copy of an array.

sort_complex(a)

Sort a complex array using the real part first, then the imaginary part.

split(ary, indices_or_sections[, axis])

Split an array into multiple sub-arrays.

sqrt(x)

Return the non-negative square-root of an array, element-wise.

square(x)

Return the element-wise square of the input.

squeeze(a[, axis])

Remove single-dimensional entries from the shape of an array.

stack(arrays[, axis, out])

Join a sequence of arrays along a new axis.

std(a[, axis, dtype, out, ddof, keepdims])

Compute the standard deviation along the specified axis.

subtract(x1, x2)

Subtract arguments, element-wise.

sum(a[, axis, dtype, out, keepdims, …])

Sum of array elements over a given axis.

swapaxes(a, axis1, axis2)

Interchange two axes of an array.

take(a, indices[, axis, out, mode])

Take elements from an array along an axis.

take_along_axis(arr, indices, axis)

Take values from the input array by matching 1d index and data slices.
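A common pairing is argsort with take_along_axis, which gathers along one axis using per-row indices. The sketch below reproduces a row-wise sort, which jnp.sort(scores, axis=1) would also give directly.

```python
import jax.numpy as jnp

scores = jnp.array([[3.0, 1.0, 2.0],
                    [0.5, 2.5, 1.5]])

# For each row, the column order that would sort that row.
order = jnp.argsort(scores, axis=1)

# Pair each row of `order` with the matching row of `scores`.
sorted_rows = jnp.take_along_axis(scores, order, axis=1)
```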

tan(x)

Compute tangent element-wise.

tanh(x)

Compute hyperbolic tangent element-wise.

tensordot(a, b[, axes, precision])

Compute tensor dot product along specified axes.

tile(A, reps)

Construct an array by repeating A the number of times given by reps.

trace(a[, offset, axis1, axis2, dtype, out])

Return the sum along diagonals of the array.

transpose(a[, axes])

Reverse or permute the axes of an array; returns the modified array.

trapz(y[, x, dx, axis])

Integrate along the given axis using the composite trapezoidal rule.

tri(N[, M, k, dtype])

An array with ones at and below the given diagonal and zeros elsewhere.

tril(m[, k])

Lower triangle of an array.

tril_indices(*args, **kwargs)

Return the indices for the lower-triangle of an (n, m) array.

tril_indices_from(arr[, k])

Return the indices for the lower-triangle of arr.

trim_zeros(filt[, trim])

Trim the leading and/or trailing zeros from a 1-D array or sequence.

triu(m[, k])

Upper triangle of an array.

triu_indices(*args, **kwargs)

Return the indices for the upper-triangle of an (n, m) array.

triu_indices_from(arr[, k])

Return the indices for the upper-triangle of arr.

true_divide(x1, x2)

Returns a true division of the inputs, element-wise.

trunc(x)

Return the truncated value of the input, element-wise.

uint16

uint32

uint64

uint8

unique(ar[, return_index, return_inverse, …])

Find the unique elements of an array.

unpackbits(a[, axis, count, bitorder])

Unpacks elements of a uint8 array into a binary-valued output array.

unravel_index(indices, shape)

Converts a flat index or array of flat indices into a tuple of coordinate arrays.

unsignedinteger

Abstract base class of all unsigned integer scalar types.

unwrap(p[, discont, axis])

Unwrap by changing deltas between values to 2*pi complement.

vander(x[, N, increasing])

Generate a Vandermonde matrix.

var(a[, axis, dtype, out, ddof, keepdims])

Compute the variance along the specified axis.

vdot(a, b, *[, precision])

Return the dot product of two vectors.

vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.
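A sketch of the signature argument, which declares the core dimensions a function consumes so that jnp.vectorize can broadcast over all remaining leading dimensions. weighted_sum is a hypothetical example function.

```python
import jax.numpy as jnp

def weighted_sum(x, w):
    # Written for a single 1-D vector `x` and matching weights `w`.
    return jnp.sum(x * w)

# "(n),(n)->()": each call consumes one trailing axis of length n per
# input and produces a scalar; leading axes are broadcast and looped over.
batched = jnp.vectorize(weighted_sum, signature="(n),(n)->()")

x = jnp.arange(6.0).reshape(2, 3)  # two vectors of length 3
w = jnp.array([1.0, 0.0, 2.0])     # one weight vector, broadcast to both rows
result = batched(x, w)             # [4., 13.]
```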

vsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays vertically (row-wise).

vstack(tup)

Stack arrays in sequence vertically (row wise).

where(condition[, x, y])

Return elements chosen from x or y depending on condition.
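The three-argument form keeps the output shape equal to the broadcast input shape, so unlike the one-argument form (a shorthand for nonzero) it is compatible with jax.jit(). A minimal example, an element-wise ReLU:

```python
import jax.numpy as jnp

x = jnp.array([-2.0, 0.5, 3.0])

# Take x where the condition holds, 0.0 elsewhere; shapes broadcast.
relu = jnp.where(x > 0, x, 0.0)  # [0. , 0.5, 3. ]
```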

zeros(shape[, dtype])

Return a new array of given shape and type, filled with zeros.

zeros_like(a[, dtype, shape])

Return an array of zeros with the same shape and type as a given array.

jax.numpy.fft

fft(a[, n, axis, norm])

Compute the one-dimensional discrete Fourier Transform.

fft2(a[, s, axes, norm])

Compute the 2-dimensional discrete Fourier Transform.

fftfreq(n[, d])

Return the Discrete Fourier Transform sample frequencies.

fftn(a[, s, axes, norm])

Compute the N-dimensional discrete Fourier Transform.

fftshift(x[, axes])

Shift the zero-frequency component to the center of the spectrum.

hfft(a[, n, axis, norm])

Compute the FFT of a signal that has Hermitian symmetry, i.e., a real spectrum.

ifft(a[, n, axis, norm])

Compute the one-dimensional inverse discrete Fourier Transform.
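A round-trip sketch: ifft inverts fft up to floating-point error, and for real input the recovered imaginary part is close to zero.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 0.0, -1.0])

spectrum = jnp.fft.fft(x)           # complex output, same length as x
recovered = jnp.fft.ifft(spectrum)  # complex; real part matches x

# The zero-frequency bin equals the sum of the input samples.
dc = spectrum[0]
```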

ifft2(a[, s, axes, norm])

Compute the 2-dimensional inverse discrete Fourier Transform.

ifftn(a[, s, axes, norm])

Compute the N-dimensional inverse discrete Fourier Transform.

ifftshift(x[, axes])

The inverse of fftshift.

ihfft(a[, n, axis, norm])

Compute the inverse FFT of a signal that has Hermitian symmetry.

irfft(a[, n, axis, norm])

Compute the inverse of the n-point DFT for real input.

irfft2(a[, s, axes, norm])

Compute the 2-dimensional inverse FFT of a real array.

irfftn(a[, s, axes, norm])

Compute the inverse of the N-dimensional FFT of real input.

rfft(a[, n, axis, norm])

Compute the one-dimensional discrete Fourier Transform for real input.

rfft2(a[, s, axes, norm])

Compute the 2-dimensional FFT of a real array.

rfftfreq(n[, d])

Return the Discrete Fourier Transform sample frequencies (for usage with rfft, irfft).

rfftn(a[, s, axes, norm])

Compute the N-dimensional discrete Fourier Transform for real input.

jax.numpy.linalg

cholesky(a)

Cholesky decomposition.

cond(x[, p])

Compute the condition number of a matrix.

det

Compute the determinant of an array.

eig(a)

Compute the eigenvalues and right eigenvectors of a square array.

eigh(a[, UPLO, symmetrize_input])

Return the eigenvalues and eigenvectors of a complex Hermitian (conjugate symmetric) or a real symmetric matrix.
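A sketch on a 2x2 symmetric matrix whose eigenvalues are 1 and 3. eigh returns the eigenvalues in ascending order and the eigenvectors as the columns of v.

```python
import jax.numpy as jnp

a = jnp.array([[2.0, 1.0],
               [1.0, 2.0]])

w, v = jnp.linalg.eigh(a)  # w is approximately [1., 3.]

# Each column satisfies a @ v[:, i] == w[i] * v[:, i]; `v * w` scales
# column i by w[i], so this residual should be close to zero.
residual = a @ v - v * w
```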

eigvals(a)

Compute the eigenvalues of a general matrix.

eigvalsh(a[, UPLO])

Compute the eigenvalues of a complex Hermitian or real symmetric matrix.

inv(a)

Compute the (multiplicative) inverse of a matrix.

lstsq(a, b[, rcond, numpy_resid])

Return the least-squares solution to a linear matrix equation.

matrix_power(a, n)

Raise a square matrix to the (integer) power n.

matrix_rank(M[, tol])

Return the matrix rank of an array using the SVD method.

multi_dot(arrays, *[, precision])

Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.

norm(x[, ord, axis, keepdims])

Matrix or vector norm.

pinv

Compute the (Moore-Penrose) pseudo-inverse of a matrix.

qr(a[, mode])

Compute the QR factorization of a matrix.

slogdet

Compute the sign and (natural) logarithm of the determinant of an array.

solve(a, b)

Solve a linear matrix equation, or system of linear scalar equations.
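A small example for the system 3*x0 + x1 = 9 and x0 + 2*x1 = 8, whose solution is x = [2, 3]. Solving directly is generally preferable to forming inv(a) @ b, both for accuracy and for cost.

```python
import jax.numpy as jnp

# Coefficient matrix and right-hand side of the 2x2 system.
a = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([9.0, 8.0])

x = jnp.linalg.solve(a, b)  # approximately [2., 3.]
```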

svd(a[, full_matrices, compute_uv])

Singular Value Decomposition.

tensorinv(a[, ind])

Compute the ‘inverse’ of an N-dimensional array.

tensorsolve(a, b[, axes])

Solve the tensor equation a x = b for x.