jax.numpy package

Implements the NumPy API, using the primitives in jax.lax.

While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.

  • Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, JAX is often able to provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides an alternative pure indexed update function jax.ops.index_update().
  • NumPy aggressively promotes values to the float64 type; JAX is sometimes less aggressive about type promotion.
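A minimal sketch of both points above. It assumes a recent JAX release, in which the indexed-update helper jax.ops.index_update() has since been replaced by the `x.at[i].set(y)` syntax, and default configuration (64-bit mode, the `jax_enable_x64` option, left off):

```python
import numpy as np
import jax.numpy as jnp

# NumPy: item assignment mutates the existing array in place.
x_np = np.zeros(3)
x_np[1] = 5.0

# JAX: arrays are immutable, so item assignment raises a TypeError.
x = jnp.zeros(3)
try:
    x[1] = 5.0
    setitem_failed = False
except TypeError:
    setitem_failed = True

# Purely functional update (recent JAX syntax): returns a new array
# and leaves x unchanged.
y = x.at[1].set(5.0)

# Type promotion: NumPy defaults to float64, while JAX uses 32-bit floats
# unless 64-bit mode (the jax_enable_x64 config option) is enabled.
numpy_default = np.zeros(3).dtype   # float64
jax_default = jnp.zeros(3).dtype    # float32 under default settings

print(setitem_failed, x, y, numpy_default, jax_default)
```

Because `y` is a fresh array, any other code still holding `x` keeps seeing the original values, which is what makes the update safe to use inside transformed functions.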

A small number of NumPy operations that have data-dependent output shapes are incompatible with jax.jit() compilation. The XLA compiler requires that shapes of arrays be known at compile time. While it would be possible to provide a JAX implementation of an API such as numpy.nonzero(), we would be unable to JIT-compile it because the shape of its output depends on the contents of the input data.
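A short sketch of the nonzero() case, assuming a recent JAX release in which jnp.nonzero accepts a static `size` keyword (and a `fill_value` for padding) so that the output shape no longer depends on the data:

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.array([0, 3, 0, 7])

# Eagerly (outside jit), the output length can depend on the data.
(idx,) = jnp.nonzero(x)  # indices of the two non-zero entries

@jax.jit
def nonzero_traced(v):
    return jnp.nonzero(v)

# Under jit the values of v are unknown at trace time, so the output
# shape cannot be determined and tracing fails.
try:
    nonzero_traced(x)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

# Fixing the output shape with a static size makes the function jittable;
# unused slots are padded with fill_value.
@jax.jit
def nonzero_fixed(v):
    return jnp.nonzero(v, size=3, fill_value=-1)

(idx_fixed,) = nonzero_fixed(x)
print(idx, jit_failed, idx_fixed)
```

The trade-off is that the caller must choose an upper bound on the number of results in advance and handle the padded entries.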

Not every function in NumPy is implemented; contributions are welcome!

abs(x) Calculate the absolute value element-wise.
absolute(x) Calculate the absolute value element-wise.
add(x1, x2) Add arguments element-wise.
all(a[, axis, dtype, out, keepdims]) Test whether all array elements along a given axis evaluate to True.
allclose(a, b[, rtol, atol]) Returns True if two arrays are element-wise equal within a tolerance.
alltrue(a[, axis, dtype, out, keepdims]) Test whether all array elements along a given axis evaluate to True.
amax(a[, axis, dtype, out, keepdims]) Return the maximum of an array or maximum along an axis.
amin(a[, axis, dtype, out, keepdims]) Return the minimum of an array or minimum along an axis.
angle(z) Return the angle of the complex argument.
any(a[, axis, dtype, out, keepdims]) Test whether any array element along a given axis evaluates to True.
append(arr, values[, axis]) Append values to the end of an array.
arange(start[, stop, step, dtype]) Return evenly spaced values within a given interval.
arccos(x) Trigonometric inverse cosine, element-wise.
arccosh(x) Inverse hyperbolic cosine, element-wise.
arcsin(x) Inverse sine, element-wise.
arcsinh(x) Inverse hyperbolic sine, element-wise.
arctan(x) Trigonometric inverse tangent, element-wise.
arctan2(x1, x2) Element-wise arc tangent of x1/x2 choosing the quadrant correctly.
arctanh(x) Inverse hyperbolic tangent, element-wise.
argmax(a[, axis]) Returns the indices of the maximum values along an axis.
argmin(a[, axis]) Returns the indices of the minimum values along an axis.
argsort(a[, axis, kind, order]) Returns the indices that would sort an array.
around(a[, decimals]) Round an array to the given number of decimals.
array(object[, dtype, copy, order, ndmin]) Create an array.
array_repr(arr[, max_line_width, precision, …]) Return the string representation of an array.
array_str(a[, max_line_width, precision, …]) Return a string representation of the data in an array.
asarray(a[, dtype, order]) Convert the input to an array.
atleast_1d(*arys) Convert inputs to arrays with at least one dimension.
atleast_2d(*arys) View inputs as arrays with at least two dimensions.
atleast_3d(*arys) View inputs as arrays with at least three dimensions.
bartlett(*args, **kwargs) Return the Bartlett window.
bitwise_and(x1, x2) Compute the bit-wise AND of two arrays element-wise.
bitwise_not(x) Compute bit-wise inversion, or bit-wise NOT, element-wise.
bitwise_or(x1, x2) Compute the bit-wise OR of two arrays element-wise.
bitwise_xor(x1, x2) Compute the bit-wise XOR of two arrays element-wise.
blackman(*args, **kwargs) Return the Blackman window.
block(arrays) Assemble an nd-array from nested lists of blocks.
broadcast_arrays(*args) Like NumPy’s broadcast_arrays but doesn’t return views.
broadcast_to(arr, shape) Like NumPy’s broadcast_to but doesn’t necessarily return views.
can_cast(from_, to[, casting]) Returns True if cast between data types can occur according to the casting rule.
ceil(x) Return the ceiling of the input, element-wise.
clip(a[, a_min, a_max]) Clip (limit) the values in an array.
column_stack(tup) Stack 1-D arrays as columns into a 2-D array.
concatenate(arrays[, axis]) Join a sequence of arrays along an existing axis.
conj(x) Return the complex conjugate, element-wise.
conjugate(x) Return the complex conjugate, element-wise.
corrcoef(x[, y, rowvar, bias, ddof]) Return Pearson product-moment correlation coefficients.
cos(x) Cosine element-wise.
cosh(x) Hyperbolic cosine, element-wise.
count_nonzero(a[, axis]) Counts the number of non-zero values in the array a.
cov(m[, y, rowvar, bias, ddof, fweights, …]) Estimate a covariance matrix, given data and weights.
cross(a, b[, axisa, axisb, axisc, axis]) Return the cross product of two (arrays of) vectors.
cumsum(a[, axis, dtype]) Return the cumulative sum of the elements along a given axis.
cumprod(a[, axis, dtype]) Return the cumulative product of elements along a given axis.
cumproduct(a[, axis, dtype]) Return the cumulative product of elements along a given axis.
deg2rad(x) Convert angles from degrees to radians.
degrees(x) Convert angles from radians to degrees.
diag(v[, k]) Extract a diagonal or construct a diagonal array.
diag_indices(n[, ndim]) Return the indices to access the main diagonal of an array.
diagonal(a[, offset, axis1, axis2]) Return specified diagonals.
divide(x1, x2) Returns a true division of the inputs, element-wise.
divmod(x1, x2) Return element-wise quotient and remainder simultaneously.
dot(a, b[, precision]) Dot product of two arrays.
dsplit(ary, indices_or_sections) Split array into multiple sub-arrays along the 3rd axis (depth).
dstack(tup) Stack arrays in sequence depth wise (along third axis).
einsum(*operands, **kwargs) Evaluates the Einstein summation convention on the operands.
equal(x1, x2) Return (x1 == x2) element-wise.
empty(shape[, dtype]) Return a new array of given shape and type, filled with zeros.
empty_like(x[, dtype]) Return an array of zeros with the same shape and type as a given array.
exp(x) Calculate the exponential of all elements in the input array.
exp2(x) Calculate 2**p for all p in the input array.
expand_dims(a, axis) Expand the shape of an array.
expm1(x) Calculate exp(x) - 1 for all elements in the array.
eye(N[, M, k, dtype]) Return a 2-D array with ones on the diagonal and zeros elsewhere.
fabs(x) Compute the absolute values element-wise.
fix(x[, out]) Round to nearest integer towards zero.
flip(m[, axis]) Reverse the order of elements in an array along the given axis.
fliplr(m) Flip array in the left/right direction.
flipud(m) Flip array in the up/down direction.
float_power(x1, x2) First array elements raised to powers from second array, element-wise.
floor(x) Return the floor of the input, element-wise.
floor_divide(x1, x2) Return the largest integer smaller than or equal to the division of the inputs.
fmod(x1, x2) Return the element-wise remainder of division.
full(shape, fill_value[, dtype]) Return a new array of given shape and type, filled with fill_value.
full_like(a, fill_value[, dtype]) Return a full array with the same shape and type as a given array.
gcd(x1, x2) Returns the greatest common divisor of |x1| and |x2|.
geomspace(start, stop[, num, endpoint, …]) Return numbers spaced evenly on a log scale (a geometric progression).
greater(x1, x2) Return the truth value of (x1 > x2) element-wise.
greater_equal(x1, x2) Return the truth value of (x1 >= x2) element-wise.
hamming(*args, **kwargs) Return the Hamming window.
hanning(*args, **kwargs) Return the Hanning window.
heaviside(x1, x2) Compute the Heaviside step function.
hsplit(ary, indices_or_sections) Split an array into multiple sub-arrays horizontally (column-wise).
hstack(tup) Stack arrays in sequence horizontally (column wise).
hypot(x1, x2) Given the “legs” of a right triangle, return its hypotenuse.
identity(n[, dtype]) Return the identity array.
imag(val) Return the imaginary part of the complex argument.
inner(a, b[, precision]) Inner product of two arrays.
isclose(a, b[, rtol, atol, equal_nan]) Returns a boolean array where two arrays are element-wise equal within a tolerance.
iscomplex(x) Returns a bool array, where True if input element is complex.
isfinite(x) Test element-wise for finiteness (not infinity and not Not a Number).
isinf(x) Test element-wise for positive or negative infinity.
isnan(x) Test element-wise for NaN and return result as a boolean array.
isneginf(infinity, x) Test element-wise for negative infinity, return result as bool array.
isposinf(infinity, x) Test element-wise for positive infinity, return result as bool array.
isreal(x) Returns a bool array, where True if input element is real.
isscalar(num) Returns True if the type of element is a scalar type.
issubdtype(arg1, arg2) Returns True if first argument is a typecode lower/equal in type hierarchy.
issubsctype(arg1, arg2) Determine if the first argument is a subclass of the second argument.
ix_(*args) Construct an open mesh from multiple sequences.
kaiser(*args, **kwargs) Return the Kaiser window.
kron(a, b) Kronecker product of two arrays.
lcm(x1, x2) Returns the lowest common multiple of |x1| and |x2|.
left_shift(x1, x2) Shift the bits of an integer to the left.
less(x1, x2) Return the truth value of (x1 < x2) element-wise.
less_equal(x1, x2) Return the truth value of (x1 <= x2) element-wise.
linspace(start, stop[, num, endpoint, …]) Return evenly spaced numbers over a specified interval.
log(x) Natural logarithm, element-wise.
log10(x) Return the base 10 logarithm of the input array, element-wise.
log1p(x) Return the natural logarithm of one plus the input array, element-wise.
log2(x) Base-2 logarithm of x.
logaddexp(x1, x2) Logarithm of the sum of exponentiations of the inputs.
logaddexp2(x1, x2) Logarithm of the sum of exponentiations of the inputs in base-2.
logical_and(*args) Compute the truth value of x1 AND x2 element-wise.
logical_not(*args) Compute the truth value of NOT x element-wise.
logical_or(*args) Compute the truth value of x1 OR x2 element-wise.
logical_xor(*args) Compute the truth value of x1 XOR x2, element-wise.
logspace(start, stop[, num, endpoint, base, …]) Return numbers spaced evenly on a log scale.
matmul(a, b[, precision]) Matrix product of two arrays.
max(a[, axis, dtype, out, keepdims]) Return the maximum of an array or maximum along an axis.
maximum(x1, x2) Element-wise maximum of array elements.
mean(a[, axis, dtype, out, keepdims]) Compute the arithmetic mean along the specified axis.
median(a[, axis, out, overwrite_input, keepdims]) Compute the median along the specified axis.
meshgrid(*args, **kwargs) Return coordinate matrices from coordinate vectors.
min(a[, axis, dtype, out, keepdims]) Return the minimum of an array or minimum along an axis.
minimum(x1, x2) Element-wise minimum of array elements.
mod(x1, x2) Return element-wise remainder of division.
moveaxis(a, source, destination) Move axes of an array to new positions.
msort(a) Return a copy of an array sorted along the first axis.
multiply(x1, x2) Multiply arguments element-wise.
nan_to_num(x[, copy]) Replace NaN with zero and infinity with large finite numbers.
nancumprod(a[, axis, dtype]) Return the cumulative product of array elements over a given axis treating Not a Numbers (NaNs) as one.
nancumsum(a[, axis, dtype]) Return the cumulative sum of array elements over a given axis treating Not a Numbers (NaNs) as zero.
nanmax(a[, axis, out, keepdims]) Return the maximum of an array or maximum along an axis, ignoring any NaNs.
nanmin(a[, axis, out, keepdims]) Return minimum of an array or minimum along an axis, ignoring any NaNs.
nanprod(a[, axis, out, keepdims]) Return the product of array elements over a given axis treating Not a Numbers (NaNs) as ones.
nansum(a[, axis, out, keepdims]) Return the sum of array elements over a given axis treating Not a Numbers (NaNs) as zero.
negative(x) Numerical negative, element-wise.
nextafter(x1, x2) Return the next floating-point value after x1 towards x2, element-wise.
nonzero(a) Return the indices of the elements that are non-zero.
not_equal(x1, x2) Return (x1 != x2) element-wise.
ones(shape[, dtype]) Return a new array of given shape and type, filled with ones.
ones_like(x[, dtype]) Return an array of ones with the same shape and type as a given array.
outer(a, b[, out]) Compute the outer product of two vectors.
pad(array, pad_width[, mode, constant_values]) Pad an array.
percentile(a, q[, axis, out, …]) Compute the q-th percentile of the data along the specified axis.
polyval(p, x) Evaluate a polynomial at specific values.
power(x1, x2) First array elements raised to powers from second array, element-wise.
positive(x) Numerical positive, element-wise.
prod(a[, axis, dtype, out, keepdims]) Return the product of array elements over a given axis.
product(a[, axis, dtype, out, keepdims]) Return the product of array elements over a given axis.
promote_types(a, b) Returns the type to which a binary operation should cast its arguments.
ptp(a[, axis, out, keepdims]) Range of values (maximum - minimum) along an axis.
quantile(a, q[, axis, out, overwrite_input, …]) Compute the q-th quantile of the data along the specified axis.
rad2deg(x) Convert angles from radians to degrees.
radians(x) Convert angles from degrees to radians.
ravel(a[, order]) Return a contiguous flattened array.
real(val) Return the real part of the complex argument.
reciprocal(x) Return the reciprocal of the argument, element-wise.
remainder(x1, x2) Return element-wise remainder of division.
repeat(a, repeats[, axis]) Repeat elements of an array.
reshape(a, newshape[, order]) Gives a new shape to an array without changing its data.
result_type(*args) Returns the type that results from applying the NumPy type promotion rules to the arguments.
right_shift(x1, x2) Shift the bits of an integer to the right.
roll(a, shift[, axis]) Roll array elements along a given axis.
rot90(m[, k, axes]) Rotate an array by 90 degrees in the plane specified by axes.
round(a[, decimals]) Round an array to the given number of decimals.
row_stack(tup) Stack arrays in sequence vertically (row wise).
select(condlist, choicelist[, default]) Return an array drawn from elements in choicelist, depending on conditions.
sign(x) Returns an element-wise indication of the sign of a number.
signbit(x) Returns element-wise True where signbit is set (less than zero).
sin(x) Trigonometric sine, element-wise.
sinc(x) Return the sinc function.
sinh(x) Hyperbolic sine, element-wise.
sometrue(a[, axis, dtype, out, keepdims]) Test whether any array element along a given axis evaluates to True.
sort(a[, axis, kind, order]) Return a sorted copy of an array.
split(ary, indices_or_sections[, axis]) Split an array into multiple sub-arrays.
sqrt(x) Return the non-negative square-root of an array, element-wise.
square(x) Return the element-wise square of the input.
squeeze(a[, axis]) Remove single-dimensional entries from the shape of an array.
stack(arrays[, axis]) Join a sequence of arrays along a new axis.
std(a[, axis, dtype, out, ddof, keepdims]) Compute the standard deviation along the specified axis.
subtract(x1, x2) Subtract arguments, element-wise.
sum(a[, axis, dtype, out, keepdims]) Sum of array elements over a given axis.
swapaxes(a, axis1, axis2) Interchange two axes of an array.
take(a, indices[, axis, out, mode]) Take elements from an array along an axis.
take_along_axis(arr, indices, axis) Take values from the input array by matching 1d index and data slices.
tan(x) Compute tangent element-wise.
tanh(x) Compute hyperbolic tangent element-wise.
tensordot(a, b[, axes, precision]) Compute tensor dot product along specified axes.
tile(a, reps) Construct an array by repeating a the number of times given by reps.
trace(a[, offset, axis1, axis2, dtype, out]) Return the sum along diagonals of the array.
transpose(a[, axes]) Permute the dimensions of an array.
tri(N[, M, k, dtype]) An array with ones at and below the given diagonal and zeros elsewhere.
tril(m[, k]) Lower triangle of an array.
tril_indices(*args, **kwargs) Return the indices for the lower-triangle of an (n, m) array.
triu(m[, k]) Upper triangle of an array.
triu_indices(*args, **kwargs) Return the indices for the upper-triangle of an (n, m) array.
true_divide(x1, x2) Returns a true division of the inputs, element-wise.
vander(x[, N, increasing]) Generate a Vandermonde matrix.
var(a[, axis, dtype, out, ddof, keepdims]) Compute the variance along the specified axis.
vdot(a, b[, precision]) Return the dot product of two vectors.
vsplit(ary, indices_or_sections) Split an array into multiple sub-arrays vertically (row-wise).
vstack(tup) Stack arrays in sequence vertically (row wise).
where(condition[, x, y]) Return elements chosen from x or y depending on condition.
zeros(shape[, dtype]) Return a new array of given shape and type, filled with zeros.
zeros_like(x[, dtype]) Return an array of zeros with the same shape and type as a given array.
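Most functions in the table above mirror the calling conventions of their NumPy counterparts, so jnp can often be swapped in directly. A small sketch comparing a few of them on made-up data (illustrative only):

```python
import numpy as np
import jax.numpy as jnp

data = [[1.0, np.nan, 3.0],
        [4.0, 5.0, 6.0]]

a_np = np.array(data)
a_jx = jnp.array(data)

# NaN-aware reduction: NaNs are treated as zero in the sum.
sums_np = np.nansum(a_np, axis=1)
sums_jx = jnp.nansum(a_jx, axis=1)

# Element-wise selection: replace NaNs with 0.0 using where and isnan.
cleaned = jnp.where(jnp.isnan(a_jx), 0.0, a_jx)

print(sums_np, sums_jx, cleaned)
```

Note that the JAX results are 32-bit under default settings, while the NumPy ones are float64, so comparisons between the two should use a tolerance rather than exact equality.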

jax.numpy.fft

fft(a[, n, axis, norm]) Compute the one-dimensional discrete Fourier Transform.
ifft(a[, n, axis, norm]) Compute the one-dimensional inverse discrete Fourier Transform.
fft2(a[, s, axes, norm]) Compute the 2-dimensional discrete Fourier Transform.
ifft2(a[, s, axes, norm]) Compute the 2-dimensional inverse discrete Fourier Transform.
fftn(a[, s, axes, norm]) Compute the N-dimensional discrete Fourier Transform.
ifftn(a[, s, axes, norm]) Compute the N-dimensional inverse discrete Fourier Transform.
rfft(a[, n, axis, norm]) Compute the one-dimensional discrete Fourier Transform for real input.
irfft(a[, n, axis, norm]) Compute the inverse of the n-point DFT for real input.
rfft2(a[, s, axes, norm]) Compute the 2-dimensional FFT of a real array.
irfft2(a[, s, axes, norm]) Compute the 2-dimensional inverse FFT of a real array.
rfftn(a[, s, axes, norm]) Compute the N-dimensional discrete Fourier Transform for real input.
irfftn(a[, s, axes, norm]) Compute the inverse of the N-dimensional FFT of real input.
fftfreq(n[, d]) Return the Discrete Fourier Transform sample frequencies.
rfftfreq(n[, d]) Return the Discrete Fourier Transform sample frequencies (for usage with rfft, irfft).
fftshift(x[, axes]) Shift the zero-frequency component to the center of the spectrum.
ifftshift(x[, axes]) The inverse of fftshift. Although identical for even-length x, the functions differ by one sample for odd-length x.
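A minimal sketch of a few jax.numpy.fft functions: locating the frequency of a sine wave with rfft and rfftfreq, inverting with irfft, and showing why fftshift and ifftshift are not interchangeable for odd lengths. The signal is made up for illustration:

```python
import numpy as np
import jax.numpy as jnp

n = 8
t = jnp.arange(n) / n                      # 8 samples over one time unit
signal = jnp.sin(2 * jnp.pi * 2.0 * t)     # sine wave with 2 cycles per unit

spectrum = jnp.fft.rfft(signal)            # n // 2 + 1 = 5 complex bins
freqs = jnp.fft.rfftfreq(n, d=1.0 / n)     # bin frequencies 0, 1, 2, 3, 4
peak = freqs[jnp.argmax(jnp.abs(spectrum))]

# Inverse transform back to n real samples.
recovered = jnp.fft.irfft(spectrum, n=n)

# For an odd length, ifftshift undoes fftshift, but a second fftshift does not.
odd = jnp.arange(5)
shifted = jnp.fft.fftshift(odd)            # zero-frequency entry moved to the center
unshifted = jnp.fft.ifftshift(shifted)     # original order restored
double_shift = jnp.fft.fftshift(shifted)   # off by one sample for odd length

print(peak, recovered, shifted, unshifted, double_shift)
```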

jax.numpy.linalg

cholesky(a) Cholesky decomposition.
det(a) Compute the determinant of an array.
eig(a) Compute the eigenvalues and right eigenvectors of a square array.
eigh(a[, UPLO, symmetrize_input]) Return the eigenvalues and eigenvectors of a complex Hermitian (conjugate symmetric) or a real symmetric matrix.
eigvals(a) Compute the eigenvalues of a general matrix.
eigvalsh(a[, UPLO]) Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
inv(a) Compute the (multiplicative) inverse of a matrix.
matrix_power(a, n) Raise a square matrix to the (integer) power n.
matrix_rank(M[, tol]) Return matrix rank of array using SVD method.
norm(x[, ord, axis, keepdims]) Matrix or vector norm.
pinv(a[, rcond]) Compute the (Moore-Penrose) pseudo-inverse of a matrix.
qr(a[, mode]) Compute the QR factorization of a matrix.
slogdet Compute the sign and (natural) logarithm of the determinant of an array.
solve(a, b) Solve a linear matrix equation, or system of linear scalar equations.
svd(a[, full_matrices, compute_uv]) Singular Value Decomposition.
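A short sketch exercising solve, slogdet and eigh from jax.numpy.linalg on a small made-up symmetric positive-definite matrix, chosen so the results can be checked by hand (det = 3*2 - 1*1 = 5, solution [2, 3]):

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])   # small symmetric positive-definite matrix
b = jnp.array([9.0, 8.0])

# Solve a @ x = b directly rather than forming inv(a) and multiplying.
x = jnp.linalg.solve(a, b)

# slogdet returns the sign and the natural log of |det(a)|.
sign, logabsdet = jnp.linalg.slogdet(a)

# eigh assumes Hermitian/symmetric input; eigenvalues come back in ascending order.
w, v = jnp.linalg.eigh(a)
reconstructed = v @ jnp.diag(w) @ v.T    # should reproduce a

print(x, sign, logabsdet, w)
```

Using slogdet instead of det avoids overflow or underflow when the determinant of a larger matrix is very large or very small in magnitude.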