jax.numpy module

Implements the NumPy API, using the primitives in jax.lax.

While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.

  • Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, often JAX is able to provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides an alternative pure indexed update function x.at[i].set(y) (see ndarray.at).

  • Relatedly, some NumPy functions return views of arrays when possible (examples are transpose() and reshape()). The JAX versions of such functions return copies instead, although XLA often optimizes these copies away when a sequence of operations is compiled with jax.jit().

  • NumPy is very aggressive at promoting values to the float64 type. JAX is sometimes less aggressive about type promotion (see Type promotion semantics).

  • Some NumPy routines have data-dependent output shapes (examples include unique() and nonzero()). Because the XLA compiler requires array shapes to be known at compile time, such operations are not compatible with JIT. For this reason, JAX adds an optional size argument to such functions which may be specified statically in order to use them with JIT.
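A minimal sketch of the size argument in practice, using nonzero() under jax.jit(). The function name first_nonzero_indices and the input values are hypothetical; fill_value controls what pads the unused slots.

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_nonzero_indices(x):
    # Without `size`, the output length would depend on the values in x,
    # which jit cannot trace. With size=3 the result always has 3 entries,
    # and unused slots are padded with fill_value.
    (idx,) = jnp.nonzero(x, size=3, fill_value=-1)
    return idx

x = jnp.array([0, 5, 0, 7])
print(first_nonzero_indices(x))  # indices 1 and 3, then one -1 of padding
```

If size is smaller than the true number of nonzero entries, the extra indices are dropped, so size should be chosen as an upper bound on the expected count.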

Nearly all applicable NumPy functions are implemented in the jax.numpy namespace; they are listed below.

ndarray.at

Helper property for index update functionality.
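The sketch below shows how ndarray.at replaces in-place NumPy updates. Each call returns a new array and leaves its input untouched; the array values are illustrative only.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# NumPy-style `x[1] = 10.0` raises a TypeError because JAX arrays are immutable.
y = x.at[1].set(10.0)                 # new array; x is unchanged
z = y.at[1:3].add(1.0)                # out-of-place version of y[1:3] += 1.0
w = z.at[jnp.array([0, 0])].add(1.0)  # repeated indices accumulate with .add

# x == [0, 0, 0, 0, 0]
# y == [0, 10, 0, 0, 0]
# z == [0, 11, 1, 0, 0]
# w == [2, 11, 1, 0, 0]
```

Note the last line: with .add, a repeated index contributes once per occurrence, whereas NumPy's `a[idx] += 1` with repeated indices applies the update only once.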

abs(x, /)

Alias of jax.numpy.absolute().

absolute(x, /)

Calculate the absolute value element-wise.

acos(x, /)

Trigonometric inverse cosine, element-wise.

acosh(x, /)

Inverse hyperbolic cosine, element-wise.

add(x1, x2, /)

Add arguments element-wise.

all(a[, axis, out, keepdims, where])

Test whether all array elements along a given axis evaluate to True.

allclose(a, b[, rtol, atol, equal_nan])

Returns True if two arrays are element-wise equal within a tolerance.

amax(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis.

amin(a[, axis, out, keepdims, initial, where])

Return the minimum of an array or minimum along an axis.

angle(z[, deg])

Return the angle of a complex valued number or array.

any(a[, axis, out, keepdims, where])

Test whether any array element along a given axis evaluates to True.

append(arr, values[, axis])

Return a new array with values appended to the end of the original array.

apply_along_axis(func1d, axis, arr, *args, ...)

Apply a function to 1-D slices along the given axis.

apply_over_axes(func, a, axes)

Apply a function repeatedly over multiple axes.

arange(start[, stop, step, dtype])

Return evenly spaced values within a given interval.

arccos(x, /)

Trigonometric inverse cosine, element-wise.

arccosh(x, /)

Inverse hyperbolic cosine, element-wise.

arcsin(x, /)

Inverse sine, element-wise.

arcsinh(x, /)

Inverse hyperbolic sine element-wise.

arctan(x, /)

Trigonometric inverse tangent, element-wise.

arctan2(x1, x2, /)

Element-wise arc tangent of x1/x2 choosing the quadrant correctly.

arctanh(x, /)

Inverse hyperbolic tangent element-wise.

argmax(a[, axis, out, keepdims])

Returns the indices of the maximum values along an axis.

argmin(a[, axis, out, keepdims])

Returns the indices of the minimum values along an axis.

argpartition(a, kth[, axis])

Returns indices that partially sort an array.

argsort(a[, axis, kind, order, stable, ...])

Returns the indices that would sort an array.

argwhere(a, *[, size, fill_value])

Find the indices of nonzero array elements.

around(a[, decimals, out])

Round an array to the given number of decimals.

array(object[, dtype, copy, order, ndmin])

Create an array.

array_equal(a1, a2[, equal_nan])

True if two arrays have the same shape and elements, False otherwise.

array_equiv(a1, a2)

Returns True if input arrays are shape consistent and all elements equal.

array_repr(arr[, max_line_width, precision, ...])

Return the string representation of an array.

array_split(ary, indices_or_sections[, axis])

Split an array into multiple sub-arrays.

array_str(a[, max_line_width, precision, ...])

Return a string representation of the data in an array.

asarray(a[, dtype, order, copy])

Convert the input to an array.

asin(x, /)

Inverse sine, element-wise.

asinh(x, /)

Inverse hyperbolic sine element-wise.

astype(x, dtype, /, *[, copy, device])

Convert an array to the specified dtype. This is implemented via jax.lax.convert_element_type(), which may behave slightly differently from numpy.astype() in some cases.
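A small sketch of a float-to-integer conversion with astype(). For in-range values the conversion truncates toward zero; the differences from NumPy mentioned above are most likely to show up for out-of-range values and NaNs, which this example deliberately avoids.

```python
import jax.numpy as jnp

x = jnp.array([2.7, -2.7, 0.5])

# In-range floats are truncated toward zero.
as_int = jnp.astype(x, jnp.int32)
# The Array method gives the same result.
as_int_method = x.astype(jnp.int32)
```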

atan(x, /)

Trigonometric inverse tangent, element-wise.

atanh(x, /)

Inverse hyperbolic tangent element-wise.

atan2(x1, x2, /)

Element-wise arc tangent of x1/x2 choosing the quadrant correctly.

atleast_1d()

Convert inputs to arrays with at least one dimension.

atleast_2d()

View inputs as arrays with at least two dimensions.

atleast_3d()

View inputs as arrays with at least three dimensions.

average()

Compute the weighted average along the specified axis.

bartlett(M)

Return the Bartlett window.

bincount(x[, weights, minlength, length])

Count the number of occurrences of each value in an integer array.
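bincount() is one of the functions whose output shape depends on the data, so under jax.jit() its length must be fixed ahead of time. A minimal sketch, with illustrative input values:

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.array([0, 1, 1, 3])

# Eagerly, the output length is inferred as max(x) + 1.
eager_counts = jnp.bincount(x)

# Under jit, pass a static `length` so the output shape is known at trace time.
jit_counts = jax.jit(partial(jnp.bincount, length=5))(x)
```

Values greater than or equal to length are dropped rather than raising an error, so length should cover the largest value you expect.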

bitwise_and(x1, x2, /)

Compute the bit-wise AND of two arrays element-wise.

bitwise_count(x, /)

Counts the number of 1 bits in the binary representation of the absolute value of each element of x.

bitwise_invert(x, /)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

bitwise_left_shift(x1, x2, /)

Shift the bits of an integer to the left.

bitwise_not(x, /)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

bitwise_or(x1, x2, /)

Compute the bit-wise OR of two arrays element-wise.

bitwise_right_shift(x1, x2, /)

Shift the bits of an integer to the right.

bitwise_xor(x1, x2, /)

Compute the bit-wise XOR of two arrays element-wise.

blackman(M)

Return the Blackman window.

block(arrays)

Assemble an nd-array from nested lists of blocks.

bool_(x)

broadcast_arrays(*args)

Broadcast any number of arrays against each other.

broadcast_shapes()

Broadcast the input shapes into a single shape.

broadcast_to(array, shape)

Broadcast an array to a new shape.

c_

Concatenate slices, scalars and array-like objects along the last axis.

can_cast(from_, to[, casting])

Returns True if cast between data types can occur according to the casting rule.

cbrt(x, /)

Return the cube-root of an array, element-wise.

cdouble

alias of complex128

ceil(x, /)

Return the ceiling of the input, element-wise.

character()

Abstract base class of all character string scalar types.

choose(a, choices[, out, mode])

Construct an array from an index array and a list of arrays to choose from.

clip([x, min, max, a, a_min, a_max])

Clip (limit) the values in an array.

column_stack(tup)

Stack 1-D arrays as columns into a 2-D array.

complex_

alias of complex128

complex128(x)

complex64(x)

complexfloating()

Abstract base class of all complex number scalar types that are made up of floating-point numbers.

ComplexWarning

The warning raised when casting a complex dtype to a real dtype.

compress(condition, a[, axis, size, ...])

Compress an array along a given axis using a boolean condition.

concat(arrays, /, *[, axis])

Join a sequence of arrays along an existing axis.

concatenate(arrays[, axis, dtype])

Join a sequence of arrays along an existing axis.

conj(x, /)

Return the complex conjugate, element-wise.

conjugate(x, /)

Return the complex conjugate, element-wise.

convolve(a, v[, mode, precision, ...])

Returns the discrete, linear convolution of two one-dimensional sequences.

copy(a[, order])

Return an array copy of the given object.

copysign(x1, x2, /)

Change the sign of x1 to that of x2, element-wise.

corrcoef(x[, y, rowvar])

Return Pearson product-moment correlation coefficients.

correlate(a, v[, mode, precision, ...])

Cross-correlation of two 1-dimensional sequences.

cos(x, /)

Cosine element-wise.

cosh(x, /)

Hyperbolic cosine, element-wise.

count_nonzero(a[, axis, keepdims])

Counts the number of non-zero values in the array a.

cov(m[, y, rowvar, bias, ddof, fweights, ...])

Estimate a covariance matrix, given data and weights.

cross(a, b[, axisa, axisb, axisc, axis])

Return the cross product of two (arrays of) vectors.

csingle

alias of complex64

cumprod(a[, axis, dtype, out])

Return the cumulative product of elements along a given axis.

cumsum(a[, axis, dtype, out])

Return the cumulative sum of the elements along a given axis.

cumulative_sum(x, /, *[, axis, dtype, ...])

Cumulative sum of elements along an axis.

deg2rad(x, /)

Convert angles from degrees to radians.

degrees(x, /)

Convert angles from radians to degrees.

delete(arr, obj[, axis, assume_unique_indices])

Delete entry or entries from an array.

diag(v[, k])

Extract a diagonal or construct a diagonal array.

diag_indices(n[, ndim])

Return the indices to access the main diagonal of an array.

diag_indices_from(arr)

Return the indices to access the main diagonal of an n-dimensional array.

diagflat(v[, k])

Create a two-dimensional array with the flattened input as a diagonal.

diagonal(a[, offset, axis1, axis2])

Return specified diagonals.

diff(a[, n, axis, prepend, append])

Calculate the n-th discrete difference along the given axis.

digitize(x, bins[, right])

Return the indices of the bins to which each value in input array belongs.

divide(x1, x2, /)

Divide arguments element-wise.

divmod(x1, x2, /)

Return element-wise quotient and remainder simultaneously.

dot(a, b, *[, precision, preferred_element_type])

Compute the dot product of two arrays.
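A sketch of the precision keyword on dot(). On some accelerators the default float32 matrix multiply may use faster, lower-precision passes; jax.lax.Precision.HIGHEST asks for the most accurate algorithm the backend offers. The small matrices here are illustrative, and on CPU both calls typically agree exactly.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
b = jnp.array([[5.0],
               [6.0]])

# Backend-default precision.
c_default = jnp.dot(a, b)
# Most accurate precision available on the backend.
c_highest = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)
```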

double

alias of float64

dsplit(ary, indices_or_sections)

Split array into multiple sub-arrays along the 3rd axis (depth).

dstack(tup[, dtype])

Stack arrays in sequence depth wise (along third axis).

dtype(dtype[, align, copy])

Create a data type object.

ediff1d(ary[, to_end, to_begin])

The differences between consecutive elements of an array.

einsum()

Einstein summation.
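Two short einsum() sketches: an index shared between inputs and absent from the output is summed over, which gives an ordinary matrix product, and a repeated index on one input with an empty output gives a trace. Array contents are illustrative.

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)

# 'ij,jk->ik': sum over the shared index j, i.e. a matrix product.
c = jnp.einsum('ij,jk->ik', a, b)

# 'ii->': sum of the diagonal, i.e. the trace.
t = jnp.einsum('ii->', jnp.eye(3))
```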

einsum_path()

Evaluates the optimal contraction path without evaluating the einsum.

empty(shape[, dtype, device])

Return a new array of given shape and type, without initializing entries.

empty_like(prototype[, dtype, shape, device])

Return a new array with the same shape and type as a given array.

equal(x1, x2, /)

Return (x1 == x2) element-wise.

exp(x, /)

Calculate the exponential of all elements in the input array.

exp2(x, /)

Calculate 2**p for all p in the input array.

expand_dims(a, axis)

Insert dimensions of length 1 into an array.

expm1(x, /)

Calculate exp(x) - 1 for all elements in the array.

extract(condition, arr, *[, size, fill_value])

Return the elements of an array that satisfy a condition.

eye(N[, M, k, dtype])

Return a 2-D array with ones on the diagonal and zeros elsewhere.

fabs(x, /)

Compute the absolute values element-wise.

fill_diagonal(a, val[, wrap, inplace])

Fill the main diagonal of the given array of any dimensionality.

finfo(dtype)

Machine limits for floating point types.

fix(x[, out])

Round to nearest integer towards zero.

flatnonzero(a, *[, size, fill_value])

Return indices of nonzero elements in a flattened array.

flexible()

Abstract base class of all scalar types without predefined length.

flip(m[, axis])

Reverse the order of elements of an array along the given axis.

fliplr(m)

Reverse the order of elements of an array along axis 1.

flipud(m)

Reverse the order of elements of an array along axis 0.

float_

alias of float64

float_power(x1, x2, /)

First array elements raised to powers from second array, element-wise.

float16(x)

float32(x)

float64(x)

floating()

Abstract base class of all floating-point scalar types.

floor(x, /)

Return the floor of the input, element-wise.

floor_divide(x1, x2, /)

Return the largest integer smaller than or equal to the division of the inputs.

fmax(x1, x2)

Element-wise maximum of array elements.

fmin(x1, x2)

Element-wise minimum of array elements.

fmod(x1, x2, /)

Returns the element-wise remainder of division.

frexp(x, /)

Decompose the elements of x into mantissa and twos exponent.

frombuffer(buffer[, dtype, count, offset])

Interpret a buffer as a 1-dimensional array.

fromfile(*args, **kwargs)

Unimplemented JAX wrapper for jnp.fromfile.

fromfunction(function, shape, *[, dtype])

Construct an array by executing a function over each coordinate.

fromiter(*args, **kwargs)

Unimplemented JAX wrapper for jnp.fromiter.

frompyfunc(func, /, nin, nout, *[, identity])

Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
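A sketch of frompyfunc() wrapping a plain two-argument function as a ufunc, which then supports ufunc methods such as reduce() and outer(). The wrapped lambda is illustrative; identity=0 supplies the starting value for reductions.

```python
import jax.numpy as jnp

# Wrap a scalar binary function as a ufunc with 2 inputs and 1 output.
add = jnp.frompyfunc(lambda x, y: x + y, nin=2, nout=1, identity=0)

x = jnp.arange(5)
elementwise = add(x, 1)                       # broadcasts like a NumPy ufunc
total = add.reduce(x)                         # folds the function over axis 0
table = add.outer(jnp.arange(2), jnp.arange(3))  # all pairwise sums, shape (2, 3)
```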

fromstring(string[, dtype, count])

A new 1-D array initialized from text data in a string.

from_dlpack(x, /, *[, device, copy])

Create a JAX array from an object implementing the __dlpack__ protocol.

full(shape, fill_value[, dtype, device])

Return a new array of given shape and type, filled with fill_value.

full_like(a, fill_value[, dtype, shape, device])

Return a full array with the same shape and type as a given array.

gcd(x1, x2)

Returns the greatest common divisor of |x1| and |x2|.

generic()

Base class for numpy scalar types.

geomspace(start, stop[, num, endpoint, ...])

Return numbers spaced evenly on a log scale (a geometric progression).

get_printoptions()

Return the current print options.

gradient(f, *varargs[, axis, edge_order])

Return the gradient of an N-dimensional array.

greater(x1, x2, /)

Return the truth value of (x1 > x2) element-wise.

greater_equal(x1, x2, /)

Return the truth value of (x1 >= x2) element-wise.

hamming(M)

Return the Hamming window.

hanning(M)

Return the Hanning window.

heaviside(x1, x2, /)

Compute the Heaviside step function.

histogram(a[, bins, range, weights, density])

Compute the histogram of a dataset.

histogram_bin_edges(a[, bins, range, weights])

Function to calculate only the edges of the bins used by the histogram function.

histogram2d(x, y[, bins, range, weights, ...])

Compute the bi-dimensional histogram of two data samples.

histogramdd(sample[, bins, range, weights, ...])

Compute the multidimensional histogram of some data.

hsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays horizontally (column-wise).

hstack(tup[, dtype])

Stack arrays in sequence horizontally (column wise).

hypot(x1, x2, /)

Given the "legs" of a right triangle, return its hypotenuse.

i0

Modified Bessel function of the first kind, order 0.

identity(n[, dtype])

Return the identity array.

iinfo(int_type)

Machine limits for integer types.

imag(val, /)

Return the imaginary part of the complex argument.

index_exp

A nicer way to build up index tuples for arrays.

indices()

Return an array representing the indices of a grid.

inexact()

Abstract base class of all numeric scalar types with a (potentially) inexact representation of the values in its range, such as floating-point numbers.

inner(a, b, *[, precision, ...])

Compute the inner product of two arrays.

insert(arr, obj, values[, axis])

Insert values along the given axis before the given indices.

int_

alias of int64

int16(x)

int32(x)

int64(x)

int8(x)

integer()

Abstract base class of all integer scalar types.

interp(x, xp, fp[, left, right, period])

One-dimensional linear interpolation for monotonically increasing sample points.
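A minimal interp() sketch with illustrative sample points. Points between samples are linearly interpolated, and with the default left and right arguments, points outside [xp[0], xp[-1]] take the boundary values fp[0] and fp[-1].

```python
import jax.numpy as jnp

xp = jnp.array([1.0, 2.0, 3.0])     # sample points, must be increasing
fp = jnp.array([10.0, 20.0, 30.0])  # sampled values
x = jnp.array([1.5, 2.5, 0.0, 4.0])

# 1.5 -> 15, 2.5 -> 25; 0.0 and 4.0 are clamped to fp[0] and fp[-1].
y = jnp.interp(x, xp, fp)
```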

intersect1d(ar1, ar2[, assume_unique, ...])

Compute the set intersection of two 1D arrays.

invert(x, /)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

isclose(a, b[, rtol, atol, equal_nan])

Returns a boolean array where two arrays are element-wise equal within a tolerance.

iscomplex(x)

Returns a boolean array that is True where the input element has a nonzero imaginary part.

iscomplexobj(x)

Check for a complex type or an array of complex numbers.

isdtype(dtype, kind)

Returns a boolean indicating whether a provided dtype is of a specified kind.

isfinite(x, /)

Test element-wise for finiteness (not infinity and not Not a Number).

isin(element, test_elements[, ...])

Determine whether elements in element appear in test_elements.

isinf(x, /)

Test element-wise for positive or negative infinity.

isnan(x, /)

Test element-wise for NaN and return result as a boolean array.

isneginf(x, /[, out])

Test element-wise for negative infinity, return result as bool array.

isposinf(x, /[, out])

Test element-wise for positive infinity, return result as bool array.

isreal(x)

Returns a boolean array that is True where the input element is real (has a zero imaginary part).

isrealobj(x)

Return True if x is neither a complex type nor an array of complex numbers.

isscalar(element)

Returns True if the type of element is a scalar type.

issubdtype(arg1, arg2)

Returns True if first argument is a typecode lower/equal in type hierarchy.

iterable(y)

Check whether or not an object can be iterated over.

ix_(*args)

Return a multi-dimensional grid (open mesh) from N one-dimensional sequences.

kaiser(M, beta)

Return the Kaiser window.

kron(a, b)

Kronecker product of two arrays.

lcm(x1, x2)

Returns the lowest common multiple of |x1| and |x2|.

ldexp(x1, x2, /)

Returns x1 * 2**x2, element-wise.

left_shift(x1, x2, /)

Shift the bits of an integer to the left.

less(x1, x2, /)

Return the truth value of (x1 < x2) element-wise.

less_equal(x1, x2, /)

Return the truth value of (x1 <= x2) element-wise.

lexsort(keys[, axis])

Perform an indirect stable sort using a sequence of keys.

linspace()

Return evenly spaced numbers over a specified interval.

load(*args, **kwargs)

Load arrays or pickled objects from .npy, .npz or pickled files.

log(x, /)

Natural logarithm, element-wise.

log10(x, /)

Return the base 10 logarithm of the input array, element-wise.

log1p(x, /)

Return the natural logarithm of one plus the input array, element-wise.

log2(x, /)

Base-2 logarithm of x.

logaddexp

Logarithm of the sum of exponentiations of the inputs.

logaddexp2

Logarithm of the sum of exponentiations of the inputs in base-2.

logical_and(*args)

Compute the truth value of x1 AND x2 element-wise.

logical_not(*args)

Compute the truth value of NOT x element-wise.

logical_or(*args)

Compute the truth value of x1 OR x2 element-wise.

logical_xor(*args)

Compute the truth value of x1 XOR x2, element-wise.

logspace(start, stop[, num, endpoint, base, ...])

Return numbers spaced evenly on a log scale.

mask_indices(*args, **kwargs)

Return the indices to access (n, n) arrays, given a masking function.

matmul(a, b, *[, precision, ...])

Perform a matrix multiplication.

matrix_transpose(x, /)

Transpose the last two dimensions of an array.

max(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis.

maximum(x1, x2, /)

Element-wise maximum of array elements.

mean(a[, axis, dtype, out, keepdims, where])

Compute the arithmetic mean along the specified axis.

median(a[, axis, out, overwrite_input, keepdims])

Compute the median along the specified axis.

meshgrid(*xi[, copy, sparse, indexing])

Return a list of coordinate matrices from coordinate vectors.

mgrid

Return dense multi-dimensional "meshgrid".

min(a[, axis, out, keepdims, initial, where])

Return the minimum of an array or minimum along an axis.

minimum(x1, x2, /)

Element-wise minimum of array elements.

mod(x1, x2, /)

Returns the element-wise remainder of division.

modf(x, /[, out])

Return the fractional and integral parts of an array, element-wise.

moveaxis(a, source, destination)

Move an array axis to a new position.

multiply(x1, x2, /)

Multiply arguments element-wise.

nan_to_num(x[, copy, nan, posinf, neginf])

Replace NaN with zero and infinity with large finite numbers (default behavior), or with the values given by the nan, posinf and/or neginf keywords.

nanargmax(a[, axis, out, keepdims])

Return the indices of the maximum values in the specified axis, ignoring NaNs.

nanargmin(a[, axis, out, keepdims])

Return the indices of the minimum values in the specified axis, ignoring NaNs.

nancumprod(a[, axis, dtype, out])

Return the cumulative product of array elements over a given axis, treating NaNs as one.

nancumsum(a[, axis, dtype, out])

Return the cumulative sum of array elements over a given axis, treating NaNs as zero.

nanmax(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis, ignoring any NaNs.

nanmean(a[, axis, dtype, out, keepdims, where])

Compute the arithmetic mean along the specified axis, ignoring NaNs.

nanmedian(a[, axis, out, overwrite_input, ...])

Compute the median along the specified axis, while ignoring NaNs.

nanmin(a[, axis, out, keepdims, initial, where])

Return minimum of an array or minimum along an axis, ignoring any NaNs.

nanpercentile(a, q[, axis, out, ...])

Compute the qth percentile of the data along the specified axis, while ignoring NaN values.

nanprod(a[, axis, dtype, out, keepdims, ...])

Return the product of array elements over a given axis, treating NaNs as one.

nanquantile(a, q[, axis, out, ...])

Compute the qth quantile of the data along the specified axis, while ignoring NaN values.

nanstd(a[, axis, dtype, out, ddof, ...])

Compute the standard deviation along the specified axis, while ignoring NaNs.

nansum(a[, axis, dtype, out, keepdims, ...])

Return the sum of array elements over a given axis, treating NaNs as zero.

nanvar(a[, axis, dtype, out, ddof, ...])

Compute the variance along the specified axis, while ignoring NaNs.

ndarray

alias of Array

ndim(a)

Return the number of dimensions of an array.

negative(x, /)

Numerical negative, element-wise.

nextafter(x1, x2, /)

Return the next floating-point value after x1 towards x2, element-wise.

nonzero(a, *[, size, fill_value])

Return indices of nonzero elements of an array.

not_equal(x1, x2, /)

Return (x1 != x2) element-wise.

number()

Abstract base class of all numeric scalar types.

object_

Any Python object.

ogrid

Return open multi-dimensional "meshgrid".

ones(shape[, dtype, device])

Return a new array of given shape and type, filled with ones.

ones_like(a[, dtype, shape, device])

Return an array of ones with the same shape and type as a given array.

outer(a, b[, out])

Compute the outer product of two vectors.

packbits(a[, axis, bitorder])

Packs the elements of a binary-valued array into bits in a uint8 array.
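A round-trip sketch of packbits() and its inverse unpackbits(). With the default big-endian bit order, the first element becomes the most significant bit, so the illustrative bits below pack to 0b10110000.

```python
import jax.numpy as jnp

bits = jnp.array([1, 0, 1, 1, 0, 0, 0, 0], dtype=jnp.uint8)

packed = jnp.packbits(bits)        # one uint8 per 8 input bits: 0b10110000 == 176
unpacked = jnp.unpackbits(packed)  # recovers the original bit pattern
```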

pad(array, pad_width[, mode])

Pad an array.

partition(a, kth[, axis])

Returns a partially-sorted copy of an array.

percentile(a, q[, axis, out, ...])

Compute the q-th percentile of the data along the specified axis.

permute_dims(a, /, axes)

Permute the axes of an array.

piecewise(x, condlist, funclist, *args, **kw)

Evaluate a piecewise-defined function.

place(arr, mask, vals, *[, inplace])

Change elements of an array based on conditional and input values.

poly(seq_of_zeros)

Find the coefficients of a polynomial with the given sequence of roots.

polyadd(a1, a2)

Find the sum of two polynomials.

polyder(p[, m])

Return the derivative of the specified order of a polynomial.

polydiv(u, v, *[, trim_leading_zeros])

Returns the quotient and remainder of polynomial division.

polyfit(x, y, deg[, rcond, full, w, cov])

Least squares polynomial fit.

polyint(p[, m, k])

Return an antiderivative (indefinite integral) of a polynomial.

polymul(a1, a2, *[, trim_leading_zeros])

Find the product of two polynomials.

polysub(a1, a2)

Difference (subtraction) of two polynomials.

polyval(p, x, *[, unroll])

Evaluate a polynomial at specific values.

positive(x, /)

Numerical positive, element-wise.

pow(x1, x2, /)

First array elements raised to powers from second array, element-wise.

power(x1, x2, /)

First array elements raised to powers from second array, element-wise.

printoptions(*args, **kwargs)

Context manager for setting print options.

prod(a[, axis, dtype, out, keepdims, ...])

Return the product of array elements over a given axis.

promote_types(a, b)

Returns the type to which a binary operation should cast its arguments.

ptp(a[, axis, out, keepdims])

Range of values (maximum - minimum) along an axis.

put(a, ind, v[, mode, inplace])

Replaces specified elements of an array with given values.
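Because JAX arrays are immutable, put() must be called with inplace=False; the default inplace=True raises an error rather than mutating the input. A minimal sketch with illustrative values, alongside the equivalent ndarray.at update:

```python
import jax.numpy as jnp

a = jnp.zeros(5)
idx = jnp.array([1, 3])
vals = jnp.array([10.0, 20.0])

# Returns a new array; `a` is left unchanged.
b = jnp.put(a, idx, vals, inplace=False)
# The same update expressed with the .at property.
c = a.at[idx].set(vals)
```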

quantile(a, q[, axis, out, overwrite_input, ...])

Compute the q-th quantile of the data along the specified axis.

r_

Concatenate slices, scalars and array-like objects along the first axis.

rad2deg(x, /)

Convert angles from radians to degrees.

radians(x, /)

Convert angles from degrees to radians.

ravel(a[, order])

Flatten array into a 1-dimensional shape.

ravel_multi_index(multi_index, dims[, mode, ...])

Convert multi-dimensional indices into flat indices.

real(val, /)

Return the real part of the complex argument.

reciprocal(x, /)

Return the reciprocal of the argument, element-wise.

remainder(x1, x2, /)

Returns the element-wise remainder of division.

repeat(a, repeats[, axis, total_repeat_length])

Repeat each element of an array after itself.
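When repeats is an array, the output length depends on its values, so under jax.jit() the total_repeat_length argument must supply it statically. A minimal sketch with illustrative values:

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
repeats = jnp.array([2, 0, 1])  # repeat 1 twice, drop 2, keep 3 once

# Eagerly, the output length is computed from `repeats`.
eager = jnp.repeat(x, repeats)

# Under jit, `repeats` is traced, so give the output length up front.
jitted = jax.jit(partial(jnp.repeat, total_repeat_length=3))(x, repeats)
```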

reshape(a[, shape, order, newshape])

Return a reshaped copy of an array.

resize(a, new_shape)

Return a new array with the specified shape.

result_type(*args)

Returns the dtype that results from applying JAX's type promotion rules to the arguments.

right_shift(x1, x2, /)

Shift the bits of x1 to the right by the amount specified in x2.

rint(x, /)

Round elements of the array to the nearest integer.

roll(a, shift[, axis])

Roll array elements along a given axis.

rollaxis(a, axis[, start])

Roll the specified axis to a given position.

roots(p, *[, strip_zeros])

Return the roots of a polynomial with coefficients given in p.

rot90(m[, k, axes])

Rotate an array by 90 degrees in the plane specified by axes.

round(a[, decimals, out])

Round an array to the given number of decimals.

round_(a[, decimals, out])

Round an array to the given number of decimals.

s_

A nicer way to build up index tuples for arrays.

save(file, arr[, allow_pickle, fix_imports])

Save an array to a binary file in NumPy .npy format.

savez(file, *args, **kwds)

Save several arrays into a single file in uncompressed .npz format.

searchsorted(a, v[, side, sorter, method])

Perform a binary search within a sorted array.
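A sketch of searchsorted() showing how side changes the result for values already present in the sorted array. The method argument selects an implementation strategy (for example 'scan' or 'compare_all') that may perform differently on different backends, but it should not change the result; the inputs are illustrative.

```python
import jax.numpy as jnp

a = jnp.array([1, 3, 5, 7])   # must be sorted
v = jnp.array([0, 4, 7, 8])

left = jnp.searchsorted(a, v)                 # first valid insertion point
right = jnp.searchsorted(a, v, side='right')  # last valid insertion point
alt = jnp.searchsorted(a, v, method='compare_all')
```

Only the value 7, which occurs in a, is placed differently by side='left' and side='right'.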

select(condlist, choicelist[, default])

Return an array drawn from elements in choicelist, depending on conditions.

set_printoptions([precision, threshold, ...])

Set printing options.

setdiff1d(ar1, ar2[, assume_unique, size, ...])

Compute the set difference of two 1D arrays.

setxor1d(ar1, ar2[, assume_unique])

Compute the set-wise xor of elements in two arrays.

shape(a)

Return the shape of an array.

sign(x, /)

Returns an element-wise indication of the sign of a number.

signbit(x, /)

Returns element-wise True where signbit is set (less than zero).

signedinteger()

Abstract base class of all signed integer scalar types.

sin(x, /)

Trigonometric sine, element-wise.

sinc(x, /)

Return the normalized sinc function.

single

alias of float32

sinh(x, /)

Hyperbolic sine, element-wise.

size(a[, axis])

Return the number of elements along a given axis.

sort(a[, axis, kind, order, stable, descending])

Return a sorted copy of an array.

sort_complex(a)

Sort a complex array using the real part first, then the imaginary part.

split(ary, indices_or_sections[, axis])

Split an array into multiple sub-arrays.

sqrt(x, /)

Return the non-negative square-root of an array, element-wise.

square(x, /)

Return the element-wise square of the input.

squeeze(a[, axis])

Remove one or more length-1 axes from an array.

stack(arrays[, axis, out, dtype])

Join a sequence of arrays along a new axis.

std(a[, axis, dtype, out, ddof, keepdims, ...])

Compute the standard deviation along the specified axis.

subtract(x1, x2, /)

Subtract arguments, element-wise.

sum(a[, axis, dtype, out, keepdims, ...])

Sum of array elements over a given axis.

swapaxes(a, axis1, axis2)

Swap two axes of an array.

take(a, indices[, axis, out, mode, ...])

Take elements from an array.

take_along_axis(arr, indices, axis[, mode, ...])

Take elements from an array.

tan(x, /)

Compute tangent element-wise.

tanh(x, /)

Compute hyperbolic tangent element-wise.

tensordot(a, b[, axes, precision, ...])

Compute the tensor dot product of two N-dimensional arrays.

tile(A, reps)

Construct an array by repeating A the number of times given by reps.

trace(a[, offset, axis1, axis2, dtype, out])

Return the sum along diagonals of the array.

trapezoid(y[, x, dx, axis])

Integrate along the given axis using the composite trapezoidal rule.

transpose(a[, axes])

Return a transposed version of an N-dimensional array.

tri(N[, M, k, dtype])

An array with ones at and below the given diagonal and zeros elsewhere.

tril(m[, k])

Lower triangle of an array.

tril_indices(n[, k, m])

Return the indices for the lower-triangle of an (n, m) array.

tril_indices_from(arr[, k])

Return the indices for the lower-triangle of arr.

trim_zeros(filt[, trim])

Trim the leading and/or trailing zeros from a 1-D array or sequence.

triu(m[, k])

Upper triangle of an array.

triu_indices(n[, k, m])

Return the indices for the upper-triangle of an (n, m) array.

triu_indices_from(arr[, k])

Return the indices for the upper-triangle of arr.

true_divide(x1, x2, /)

Divide arguments element-wise.

trunc(x)

Return the truncated value of the input, element-wise.

ufunc(func, /, nin, nout, *[, name, nargs, ...])

Functions that operate element-by-element on whole arrays.

uint

alias of uint64

uint16(x)

uint32(x)

uint64(x)

uint8(x)

union1d(ar1, ar2, *[, size, fill_value])

Compute the set union of two 1D arrays.

unique(ar[, return_index, return_inverse, ...])

Return the unique values from an array.
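unique() has a data-dependent output length, so using it under jax.jit() requires a static size. A minimal sketch with illustrative values; positions beyond the number of unique elements are filled with fill_value.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.array([3, 1, 3, 2])

# Eagerly, the output has exactly one entry per distinct value.
eager = jnp.unique(x)

# Under jit, pad the sorted unique values out to a fixed length of 5.
padded = jax.jit(partial(jnp.unique, size=5, fill_value=0))(x)
```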

unique_all(x, /, *[, size, fill_value])

Return unique values from x, along with indices, inverse indices, and counts.

unique_counts(x, /, *[, size, fill_value])

Return unique values from x, along with counts.

unique_inverse(x, /, *[, size, fill_value])

Return unique values from x, along with inverse indices.

unique_values(x, /, *[, size, fill_value])

Return unique values from x.

unpackbits(a[, axis, count, bitorder])

Unpacks elements of a uint8 array into a binary-valued output array.

unravel_index(indices, shape)

Convert flat indices into multi-dimensional indices.

unstack(x, /, *[, axis])

Split an array into a sequence of arrays along the given axis.

unsignedinteger()

Abstract base class of all unsigned integer scalar types.

unwrap(p[, discont, axis, period])

Unwrap by taking the complement of large deltas with respect to the period.

vander(x[, N, increasing])

Generate a Vandermonde matrix.

var(a[, axis, dtype, out, ddof, keepdims, ...])

Compute the variance along the specified axis.

vdot(a, b, *[, precision, ...])

Perform a conjugate multiplication of two 1D vectors.

vecdot(x1, x2, /, *[, axis, precision, ...])

Perform a conjugate multiplication of two batched vectors.

vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.
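A sketch of vectorize() with a core-dimension signature. The function norm is hypothetical and written for a single 1-D vector; '(n)->()' tells vectorize to consume one trailing axis of length n and produce a scalar, broadcasting over all leading axes.

```python
import jax.numpy as jnp

def norm(v):
    # Euclidean norm of a single 1-D vector.
    return jnp.sqrt(jnp.sum(v ** 2))

# Map `norm` over every leading axis, consuming the last axis as the vector.
batched_norm = jnp.vectorize(norm, signature='(n)->()')

out = batched_norm(jnp.ones((4, 3)))  # shape (4,), each entry sqrt(3)
```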

vsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays vertically (row-wise).

vstack(tup[, dtype])

Stack arrays in sequence vertically (row wise).

where()

Select elements from two arrays based on a condition.

zeros(shape[, dtype, device])

Return a new array of given shape and type, filled with zeros.

zeros_like(a[, dtype, shape, device])

Return an array of zeros with the same shape and type as a given array.

jax.numpy.fft

fft(a[, n, axis, norm])

Compute the one-dimensional discrete Fourier Transform.
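A round-trip sketch with fft() and ifft() on an illustrative real signal. The zero-frequency term X[0] equals the sum of the input, and applying ifft() recovers the signal up to floating-point error, with a negligible imaginary part.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 0.0, -1.0])

X = jnp.fft.fft(x)        # complex spectrum, same length as x
x_back = jnp.fft.ifft(X)  # inverse transform, complex dtype
```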

fft2(a[, s, axes, norm])

Compute the 2-dimensional discrete Fourier Transform.

fftfreq(n[, d, dtype])

Return the Discrete Fourier Transform sample frequencies.

fftn(a[, s, axes, norm])

Compute the N-dimensional discrete Fourier Transform.

fftshift(x[, axes])

Shift the zero-frequency component to the center of the spectrum.

hfft(a[, n, axis, norm])

Compute the FFT of a signal that has Hermitian symmetry, i.e., a real spectrum.

ifft(a[, n, axis, norm])

Compute the one-dimensional inverse discrete Fourier Transform.

ifft2(a[, s, axes, norm])

Compute the 2-dimensional inverse discrete Fourier Transform.

ifftn(a[, s, axes, norm])

Compute the N-dimensional inverse discrete Fourier Transform.

ifftshift(x[, axes])

The inverse of fftshift.

ihfft(a[, n, axis, norm])

Compute the inverse FFT of a signal that has Hermitian symmetry.

irfft(a[, n, axis, norm])

Computes the inverse of rfft.

irfft2(a[, s, axes, norm])

Computes the inverse of rfft2.

irfftn(a[, s, axes, norm])

Computes the inverse of rfftn.

rfft(a[, n, axis, norm])

Compute the one-dimensional discrete Fourier Transform for real input.

rfft2(a[, s, axes, norm])

Compute the 2-dimensional FFT of a real array.

rfftfreq(n[, d, dtype])

Return the Discrete Fourier Transform sample frequencies (for use with rfft() and irfft()).

rfftn(a[, s, axes, norm])

Compute the N-dimensional discrete Fourier Transform for real input.
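As a rough illustration of how the real-input transforms fit together, the sketch below takes a sampled cosine through rfft(), locates its peak with rfftfreq(), inverts it with irfft(), and shows how fftshift()/ifftshift() reorder the output of fftfreq(). The signal and its length are arbitrary choices for the example.

```python
import jax.numpy as jnp

n = 8
t = jnp.arange(n)
# A real test signal: a cosine completing 2 cycles over n samples.
signal = jnp.cos(2 * jnp.pi * 2 * t / n)

# rfft keeps only the non-negative frequencies: n // 2 + 1 bins.
spectrum = jnp.fft.rfft(signal)                   # shape (5,)
freqs = jnp.fft.rfftfreq(n)                       # [0, 1, 2, 3, 4] / 8 cycles per sample
peak_freq = freqs[jnp.argmax(jnp.abs(spectrum))]  # 0.25

# irfft defaults to an output length of 2 * (len(spectrum) - 1); passing n
# explicitly keeps the round trip exact for odd-length signals as well.
recovered = jnp.fft.irfft(spectrum, n=n)

# fftfreq orders frequencies as [0, positive..., negative...];
# fftshift moves zero to the center and ifftshift restores the order.
full_freqs = jnp.fft.fftfreq(n)                   # [0, 1, 2, 3, -4, -3, -2, -1] / 8
centered = jnp.fft.fftshift(full_freqs)           # [-4, -3, ..., 3] / 8
restored = jnp.fft.ifftshift(centered)
```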

jax.numpy.linalg

cholesky(a, *[, upper])

Compute the Cholesky decomposition of a matrix.

cond(x[, p])

Compute the condition number of a matrix.

cross(x1, x2, /, *[, axis])

Compute the cross-product of two 3D vectors.

det

Compute the determinant of an array.

diagonal(x, /, *[, offset])

Extract the diagonal of a matrix or stack of matrices.

eig(a)

Compute the eigenvalues and eigenvectors of a square array.

eigh(a[, UPLO, symmetrize_input])

Compute the eigenvalues and eigenvectors of a Hermitian matrix.

eigvals(a)

Compute the eigenvalues of a general matrix.

eigvalsh(a[, UPLO])

Compute the eigenvalues of a Hermitian matrix.

inv(a)

Return the inverse of a square matrix.

lstsq(a, b[, rcond, numpy_resid])

Return the least-squares solution to a linear equation.

matmul(x1, x2, /, *[, precision, ...])

Perform a matrix multiplication.

matrix_norm(x, /, *[, keepdims, ord])

Compute the norm of a matrix or stack of matrices.

matrix_power(a, n)

Raise a square matrix to an integer power.

matrix_rank(M[, rtol, tol])

Compute the rank of a matrix.

matrix_transpose(x, /)

Transpose a matrix or stack of matrices.

multi_dot(arrays, *[, precision])

Efficiently compute matrix products between a sequence of arrays.

norm(x[, ord, axis, keepdims])

Compute the norm of a matrix or vector.

outer(x1, x2, /)

Compute the outer product of two 1-dimensional arrays.

pinv(a[, rtol, hermitian, rcond])

Compute the (Moore-Penrose) pseudo-inverse of a matrix.

qr()

Compute the QR decomposition of an array.

slogdet(a, *[, method])

Compute the sign and (natural) logarithm of the determinant of an array.

solve(a, b)

Solve a linear system of equations.

svd()

Compute the singular value decomposition.

svdvals(x, /)

Compute the singular values of a matrix.

tensordot(x1, x2, /, *[, axes, precision, ...])

Compute the tensor dot product of two N-dimensional arrays.

tensorinv(a[, ind])

Compute the tensor inverse of an array.

tensorsolve(a, b[, axes])

Solve the tensor equation a x = b for x.

trace(x, /, *[, offset, dtype])

Compute the trace of a matrix.

vector_norm(x, /, *[, axis, keepdims, ord])

Compute the vector norm of a vector or batch of vectors.

vecdot(x1, x2, /, *[, axis, precision, ...])

Compute the (batched) vector conjugate dot product of two arrays.
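A small sketch of a few jax.numpy.linalg routines on hand-picked matrices follows. It assumes the NumPy-compatible full_matrices keyword for svd(), which the listing above shows without arguments, and uses an arbitrary 200-by-200 matrix to show why slogdet() exists alongside det(). Tolerances in the reconstruction checks are loose because results are computed in float32 by default.

```python
import jax.numpy as jnp

# solve: for a single linear system this is generally preferable to inv(a) @ b.
a = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([9.0, 8.0])
x = jnp.linalg.solve(a, b)                     # [2, 3]

# eigh: for a symmetric (Hermitian) matrix, eigenvalues are returned in
# ascending order and the eigenvectors are orthonormal columns.
sym = jnp.array([[2.0, 1.0], [1.0, 2.0]])
evals, evecs = jnp.linalg.eigh(sym)            # evals = [1, 3]
sym_rebuilt = evecs @ jnp.diag(evals) @ evecs.T

# svd with full_matrices=False gives the "thin" factors; svdvals returns only s.
m = jnp.arange(6.0).reshape(2, 3)
u, s, vh = jnp.linalg.svd(m, full_matrices=False)  # shapes (2, 2), (2,), (2, 3)
m_rebuilt = (u * s) @ vh                           # scale the columns of u by s

# slogdet avoids overflow: det(100 * I_200) = 1e400 is not representable in
# float32 (or float64), but its logarithm is.
big = 100.0 * jnp.eye(200)
direct = jnp.linalg.det(big)                   # inf
sign, logabsdet = jnp.linalg.slogdet(big)      # 1.0, 200 * log(100) ~= 921.03
```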

JAX Array

The JAX Array (along with its alias, jax.numpy.ndarray) is the core array object in JAX: you can think of it as JAX’s equivalent of a numpy.ndarray. As with numpy.ndarray, most users will not need to instantiate Array objects manually; instead, they will create them via jax.numpy functions such as array(), arange(), linspace(), and the others listed above.
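A minimal sketch of what this looks like in practice: arrays built with jax.numpy functions pass isinstance checks against both jax.Array and jax.numpy.ndarray, and numpy.asarray() converts them to a host-side numpy.ndarray when one is needed.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 5)

# Arrays returned by jax.numpy functions are jax.Array instances, and the
# jax.numpy.ndarray alias works the same way in isinstance checks.
print(isinstance(x, jax.Array), isinstance(x, jnp.ndarray))  # True True

# Converting to NumPy materializes the values as a host numpy.ndarray.
host = np.asarray(x)
```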

Copying and Serialization

JAX Array objects are designed to work seamlessly with Python standard library tools where appropriate.

With the built-in copy module, when copy.copy() or copy.deepcopy() encounters an Array, the result is equivalent to calling the copy() method, which creates a copy of the buffer on the same device as the original array. This also works within traced/JIT-compiled code, though the compiler may elide copy operations in that context.
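The sketch below illustrates this copy behavior both eagerly and inside a jax.jit()-compiled function. The function name shifted_copy is a hypothetical example, and whether the traced copy is actually materialized is left to the compiler.

```python
import copy

import jax
import jax.numpy as jnp

x = jnp.arange(3.0)

# Outside tracing: equivalent to x.copy(), a new buffer on the same device.
y = copy.deepcopy(x)

@jax.jit
def shifted_copy(a):
    # Inside jit, `a` is a tracer; copy.copy still works, although the
    # compiler is free to elide the copy.
    return copy.copy(a) + 1.0

z = shifted_copy(x)
```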

When the built-in pickle module encounters an Array, it will be serialized via a compact bit representation in a similar manner to pickled numpy.ndarray objects. When unpickled, the result will be a new Array object on the default device. This is because in general, pickling and unpickling may take place in different runtime environments, and there is no general way to map the device IDs of one runtime to the device IDs of another. If pickle is used in traced/JIT-compiled code, it will result in a ConcretizationTypeError.
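Here is a minimal sketch of both cases from the paragraph above: a pickle round trip on a concrete Array, and an attempt to pickle a traced value inside jax.jit(), which is expected to raise jax.errors.ConcretizationTypeError. The function name pickle_inside_jit is hypothetical.

```python
import pickle

import jax
import jax.numpy as jnp

x = jnp.arange(4.0)

# Round trip through pickle; the restored value is a new Array placed on the
# default device, not necessarily the device x lived on.
y = pickle.loads(pickle.dumps(x))

@jax.jit
def pickle_inside_jit(a):
    # Here `a` is a tracer with no concrete value to serialize.
    return pickle.dumps(a)

try:
    pickle_inside_jit(x)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True
```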