jax.numpy module

Implements the NumPy API, using the primitives in jax.lax.

While JAX tries to follow the NumPy API as closely as possible, sometimes JAX cannot follow NumPy exactly.

  • Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays in-place cannot be implemented in JAX. However, often JAX is able to provide an alternative API that is purely functional. For example, instead of in-place array updates (x[i] = y), JAX provides an alternative pure indexed update function x.at[i].set(y) (see ndarray.at).

  • Relatedly, some NumPy functions return views of arrays when possible (examples are transpose() and reshape()). JAX versions of such functions return copies instead, although these copies are often optimized away by XLA when sequences of operations are compiled using jax.jit().

  • NumPy is very aggressive at promoting values to float64 type. JAX is sometimes less aggressive about type promotion (see Type promotion semantics).

  • Some NumPy routines have data-dependent output shapes (examples include unique() and nonzero()). Because the XLA compiler requires array shapes to be known at compile time, such operations are not compatible with JIT. For this reason, JAX adds an optional size argument to such functions which may be specified statically in order to use them with JIT.
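A minimal sketch of three of the differences above (immutability, type promotion, and static sizes under JIT), assuming a standard JAX install with default settings. The helper name nonzero_indices is hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# 1. Immutability: use x.at[i].set(y) instead of x[i] = y.
x = jnp.zeros(3)
y = x.at[1].set(5.0)          # returns a new array; x itself is unchanged
try:
    x[1] = 5.0                # in-place item assignment is not supported
except TypeError as err:
    print("TypeError:", err)

# 2. Type promotion: combining int32 with float32.
jax_dtype = jnp.promote_types(np.int32, np.float32)  # float32 under JAX rules
np_dtype = np.promote_types(np.int32, np.float32)    # float64 under NumPy rules

# 3. Data-dependent shapes: a static `size` lets nonzero() run under jit.
@jax.jit
def nonzero_indices(v):
    # Pad (or truncate) the result to exactly 3 entries, filling with -1.
    (idx,) = jnp.nonzero(v, size=3, fill_value=-1)
    return idx

idx = nonzero_indices(jnp.array([0, 5, 0, 7]))
print(y, jax_dtype, np_dtype, idx)
```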

Nearly all applicable NumPy functions are implemented in the jax.numpy namespace; they are listed below.

ndarray.at

Helper property for index update functionality.
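As a sketch of the at property beyond set(), assuming default settings: add() accumulates contributions from repeated indices (a scatter-add), and min() applies an element-wise minimum to the selected entries. The variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.zeros(4)
idx = jnp.array([0, 0, 2])

# Index 0 appears twice, so it receives both increments.
added = x.at[idx].add(1.0)

# Clamp every element from position 1 onward to at most 2.0.
clipped = jnp.arange(4.0).at[1:].min(2.0)
print(added, clipped)
```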

abs(x, /)

Calculate the absolute value element-wise.

absolute(x, /)

Calculate the absolute value element-wise.

acos(x, /)

Trigonometric inverse cosine, element-wise.

acosh(x, /)

Inverse hyperbolic cosine, element-wise.

add(x1, x2, /)

Add arguments element-wise.

all(a[, axis, out, keepdims, where])

Test whether all array elements along a given axis evaluate to True.

allclose(a, b[, rtol, atol, equal_nan])

Returns True if two arrays are element-wise equal within a tolerance.

amax(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis.

amin(a[, axis, out, keepdims, initial, where])

Return the minimum of an array or minimum along an axis.

angle(z[, deg])

Return the angle of the complex argument.

any(a[, axis, out, keepdims, where])

Test whether any array element along a given axis evaluates to True.

append(arr, values[, axis])

Append values to the end of an array.

apply_along_axis(func1d, axis, arr, *args, ...)

Apply a function to 1-D slices along the given axis.

apply_over_axes(func, a, axes)

Apply a function repeatedly over multiple axes.

arange(start[, stop, step, dtype])

Return evenly spaced values within a given interval.

arccos(x, /)

Trigonometric inverse cosine, element-wise.

arccosh(x, /)

Inverse hyperbolic cosine, element-wise.

arcsin(x, /)

Inverse sine, element-wise.

arcsinh(x, /)

Inverse hyperbolic sine element-wise.

arctan(x, /)

Trigonometric inverse tangent, element-wise.

arctan2(x1, x2, /)

Element-wise arc tangent of x1/x2 choosing the quadrant correctly.

arctanh(x, /)

Inverse hyperbolic tangent element-wise.

argmax(a[, axis, out, keepdims])

Returns the indices of the maximum values along an axis.

argmin(a[, axis, out, keepdims])

Returns the indices of the minimum values along an axis.

argpartition(a, kth[, axis])

Perform an indirect partition along the given axis.

argsort(a[, axis, kind, order, stable, ...])

Returns the indices that would sort an array.

argwhere(a, *[, size, fill_value])

Find the indices of array elements that are non-zero, grouped by element.

around(a[, decimals, out])

Round an array to the given number of decimals.

array(object[, dtype, copy, order, ndmin])

Create an array.

array_equal(a1, a2[, equal_nan])

True if two arrays have the same shape and elements, False otherwise.

array_equiv(a1, a2)

Returns True if input arrays are shape consistent and all elements equal.

array_repr(arr[, max_line_width, precision, ...])

Return the string representation of an array.

array_split(ary, indices_or_sections[, axis])

Split an array into multiple sub-arrays.

array_str(a[, max_line_width, precision, ...])

Return a string representation of the data in an array.

asarray(a[, dtype, order, copy])

Convert the input to an array.

asin(x, /)

Inverse sine, element-wise.

asinh(x, /)

Inverse hyperbolic sine element-wise.

astype(x, dtype, /, *[, copy])

Copy an array to a specified data type. This is implemented via jax.lax.convert_element_type(), which may behave slightly differently from numpy.astype() in some cases.

atan(x, /)

Trigonometric inverse tangent, element-wise.

atanh(x, /)

Inverse hyperbolic tangent element-wise.

atan2(x1, x2, /)

Element-wise arc tangent of x1/x2 choosing the quadrant correctly.

atleast_1d(*arys)

Convert inputs to arrays with at least one dimension.

atleast_2d(*arys)

View inputs as arrays with at least two dimensions.

atleast_3d(*arys)

View inputs as arrays with at least three dimensions.

average(a[, axis, weights, returned, keepdims])

Compute the weighted average along the specified axis.

bartlett(M)

Return the Bartlett window.

bincount(x[, weights, minlength, length])

Count number of occurrences of each value in array of non-negative ints.
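Because bincount has a data-dependent output length, a sketch of calling it under jax.jit passes the static length argument so the output shape is known at compile time. histogram_counts is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

@jax.jit
def histogram_counts(values):
    # `length` fixes the output shape statically, so this traces under jit.
    return jnp.bincount(values, length=4)

counts = histogram_counts(jnp.array([0, 1, 1, 3]))
print(counts)
```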

bitwise_and(x1, x2, /)

Compute the bit-wise AND of two arrays element-wise.

bitwise_count(x, /)

Count the number of 1 bits in the absolute value of the input, element-wise.

bitwise_invert(x, /)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

bitwise_left_shift(x1, x2, /)

Shift the bits of an integer to the left.

bitwise_not(x, /)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

bitwise_or(x1, x2, /)

Compute the bit-wise OR of two arrays element-wise.

bitwise_right_shift(x1, x2, /)

Shift the bits of an integer to the right.

bitwise_xor(x1, x2, /)

Compute the bit-wise XOR of two arrays element-wise.

blackman(M)

Return the Blackman window.

block(arrays)

Assemble an nd-array from nested lists of blocks.

bool_(x)

broadcast_arrays(*args)

Broadcast any number of arrays against each other.

broadcast_shapes(*shapes)

Broadcast the input shapes into a single shape.

broadcast_to(array, shape)

Broadcast an array to a new shape.

c_

Concatenate slices, scalars and array-like objects along the last axis.

can_cast(from_, to[, casting])

Returns True if cast between data types can occur according to the casting rule.

cbrt(x, /)

Return the cube-root of an array, element-wise.

cdouble

alias of complex128

ceil(x, /)

Return the ceiling of the input, element-wise.

character()

Abstract base class of all character string scalar types.

choose(a, choices[, out, mode])

Construct an array from an index array and a list of arrays to choose from.

clip(a[, a_min, a_max, out])

Clip (limit) the values in an array.

column_stack(tup)

Stack 1-D arrays as columns into a 2-D array.

complex_

alias of complex128

complex128(x)

complex64(x)

complexfloating()

Abstract base class of all complex number scalar types that are made up of floating-point numbers.

ComplexWarning

The warning raised when casting a complex dtype to a real dtype.

compress(condition, a[, axis, out])

Return selected slices of an array along given axis.

concat(arrays, /, *[, axis])

Join a sequence of arrays along an existing axis.

concatenate(arrays[, axis, dtype])

Join a sequence of arrays along an existing axis.

conj(x, /)

Return the complex conjugate, element-wise.

conjugate(x, /)

Return the complex conjugate, element-wise.

convolve(a, v[, mode, precision, ...])

Returns the discrete, linear convolution of two one-dimensional sequences.

copy(a[, order])

Return an array copy of the given object.

copysign(x1, x2, /)

Change the sign of x1 to that of x2, element-wise.

corrcoef(x[, y, rowvar])

Return Pearson product-moment correlation coefficients.

correlate(a, v[, mode, precision, ...])

Cross-correlation of two 1-dimensional sequences.

cos(x, /)

Cosine element-wise.

cosh(x, /)

Hyperbolic cosine, element-wise.

count_nonzero(a[, axis, keepdims])

Counts the number of non-zero values in the array a.

cov(m[, y, rowvar, bias, ddof, fweights, ...])

Estimate a covariance matrix, given data and weights.

cross(a, b[, axisa, axisb, axisc, axis])

Return the cross product of two (arrays of) vectors.

csingle

alias of complex64

cumprod(a[, axis, dtype, out])

Return the cumulative product of elements along a given axis.

cumsum(a[, axis, dtype, out])

Return the cumulative sum of the elements along a given axis.

deg2rad(x, /)

Convert angles from degrees to radians.

degrees(x, /)

Convert angles from radians to degrees.

delete(arr, obj[, axis, assume_unique_indices])

Return a new array with sub-arrays along an axis deleted.

diag(v[, k])

Extract a diagonal or construct a diagonal array.

diag_indices(n[, ndim])

Return the indices to access the main diagonal of an array.

diag_indices_from(arr)

Return the indices to access the main diagonal of an n-dimensional array.

diagflat(v[, k])

Create a two-dimensional array with the flattened input as a diagonal.

diagonal(a[, offset, axis1, axis2])

Return specified diagonals.

diff(a[, n, axis, prepend, append])

Calculate the n-th discrete difference along the given axis.

digitize(x, bins[, right])

Return the indices of the bins to which each value in input array belongs.

divide(x1, x2, /)

Divide arguments element-wise.

divmod(x1, x2, /)

Return element-wise quotient and remainder simultaneously.

dot(a, b, *[, precision, preferred_element_type])

Dot product of two arrays.

double

alias of float64

dsplit(ary, indices_or_sections)

Split array into multiple sub-arrays along the 3rd axis (depth).

dstack(tup[, dtype])

Stack arrays in sequence depth wise (along third axis).

dtype(dtype[, align, copy])

Create a data type object.

ediff1d(ary[, to_end, to_begin])

The differences between consecutive elements of an array.

einsum(subscripts, /, *operands[, out, ...])

Evaluates the Einstein summation convention on the operands.

einsum_path(subscripts, *operands[, optimize])

Evaluates the lowest cost contraction order for an einsum expression by considering the creation of intermediate arrays.

empty(shape[, dtype, device])

Return a new array of given shape and type, without initializing entries.

empty_like(prototype[, dtype, shape, device])

Return a new array with the same shape and type as a given array.

equal(x1, x2, /)

Return (x1 == x2) element-wise.

exp(x, /)

Calculate the exponential of all elements in the input array.

exp2(x, /)

Calculate 2**p for all p in the input array.

expand_dims(a, axis)

Expand the shape of an array.

expm1(x, /)

Calculate exp(x) - 1 for all elements in the array.

extract(condition, arr)

Return the elements of an array that satisfy some condition.

eye(N[, M, k, dtype])

Return a 2-D array with ones on the diagonal and zeros elsewhere.

fabs(x, /)

Compute the absolute values element-wise.

fill_diagonal(a, val[, wrap, inplace])

Fill the main diagonal of the given array of any dimensionality.
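Since JAX arrays cannot be modified in place, a sketch of using fill_diagonal passes inplace=False and works with the returned array; the input is left unchanged.

```python
import jax.numpy as jnp

a = jnp.zeros((3, 3))
# inplace=False returns a new array with the diagonal filled.
b = jnp.fill_diagonal(a, 1.0, inplace=False)
print(b)
```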

finfo(dtype)

Machine limits for floating point types.

fix(x[, out])

Round to nearest integer towards zero.

flatnonzero(a, *[, size, fill_value])

Return indices that are non-zero in the flattened version of a.

flexible()

Abstract base class of all scalar types without predefined length.

flip(m[, axis])

Reverse the order of elements in an array along the given axis.

fliplr(m)

Reverse the order of elements along axis 1 (left/right).

flipud(m)

Reverse the order of elements along axis 0 (up/down).

float_

alias of float64

float_power(x1, x2, /)

First array elements raised to powers from second array, element-wise.

float16(x)

float32(x)

float64(x)

floating()

Abstract base class of all floating-point scalar types.

floor(x, /)

Return the floor of the input, element-wise.

floor_divide(x1, x2, /)

Return the largest integer smaller or equal to the division of the inputs.

fmax(x1, x2)

Element-wise maximum of array elements.

fmin(x1, x2)

Element-wise minimum of array elements.

fmod(x1, x2, /)

Returns the element-wise remainder of division.

frexp(x, /)

Decompose the elements of x into mantissa and twos exponent.

frombuffer(buffer[, dtype, count, offset])

Interpret a buffer as a 1-dimensional array.

fromfile(*args, **kwargs)

Unimplemented JAX wrapper for jnp.fromfile.

fromfunction(function, shape, *[, dtype])

Construct an array by executing a function over each coordinate.

fromiter(*args, **kwargs)

Unimplemented JAX wrapper for jnp.fromiter.

frompyfunc(func, /, nin, nout, *[, identity])

Create a JAX ufunc from an arbitrary JAX-compatible scalar function.
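A sketch of building a ufunc from a scalar Python function, assuming that identity=0 is the appropriate identity for addition. scalar_add and add_ufunc are hypothetical names; the resulting object exposes ufunc methods such as reduce() and outer().

```python
import jax.numpy as jnp

def scalar_add(x, y):
    # Any JAX-compatible scalar function of two inputs works here.
    return x + y

add_ufunc = jnp.frompyfunc(scalar_add, nin=2, nout=1, identity=0)

total = add_ufunc.reduce(jnp.arange(5))                 # 0+1+2+3+4
table = add_ufunc.outer(jnp.arange(3), jnp.arange(3))   # table[i, j] = i + j
print(total, table)
```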

fromstring(string[, dtype, count])

A new 1-D array initialized from text data in a string.

from_dlpack(x)

Create a JAX array from an object implementing the __dlpack__ protocol.

full(shape, fill_value[, dtype, device])

Return a new array of given shape and type, filled with fill_value.

full_like(a, fill_value[, dtype, shape, device])

Return a full array with the same shape and type as a given array.

gcd(x1, x2)

Returns the greatest common divisor of |x1| and |x2|.

generic()

Base class for numpy scalar types.

geomspace(start, stop[, num, endpoint, ...])

Return numbers spaced evenly on a log scale (a geometric progression).

get_printoptions()

Return the current print options.

gradient(f, *varargs[, axis, edge_order])

Return the gradient of an N-dimensional array.

greater(x1, x2, /)

Return the truth value of (x1 > x2) element-wise.

greater_equal(x1, x2, /)

Return the truth value of (x1 >= x2) element-wise.

hamming(M)

Return the Hamming window.

hanning(M)

Return the Hanning window.

heaviside(x1, x2, /)

Compute the Heaviside step function.

histogram(a[, bins, range, weights, density])

Compute the histogram of a dataset.

histogram_bin_edges(a[, bins, range, weights])

Function to calculate only the edges of the bins used by the histogram function.

histogram2d(x, y[, bins, range, weights, ...])

Compute the bi-dimensional histogram of two data samples.

histogramdd(sample[, bins, range, weights, ...])

Compute the multidimensional histogram of some data.

hsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays horizontally (column-wise).

hstack(tup[, dtype])

Stack arrays in sequence horizontally (column wise).

hypot(x1, x2, /)

Given the "legs" of a right triangle, return its hypotenuse.

i0(x)

Modified Bessel function of the first kind, order 0.

identity(n[, dtype])

Return the identity array.

iinfo(int_type)

Machine limits for integer types.

imag(val, /)

Return the imaginary part of the complex argument.

index_exp

A nicer way to build up index tuples for arrays.

indices(dimensions[, dtype, sparse])

Return an array representing the indices of a grid.

inexact()

Abstract base class of all numeric scalar types with a (potentially) inexact representation of the values in its range, such as floating-point numbers.

inner(a, b, *[, precision, ...])

Inner product of two arrays.

insert(arr, obj, values[, axis])

Insert values along the given axis before the given indices.

int_

alias of int64

int16(x)

int32(x)

int64(x)

int8(x)

integer()

Abstract base class of all integer scalar types.

interp(x, xp, fp[, left, right, period])

One-dimensional linear interpolation for monotonically increasing sample points.

intersect1d(ar1, ar2[, assume_unique, ...])

Find the intersection of two arrays.

invert(x, /)

Compute bit-wise inversion, or bit-wise NOT, element-wise.

isclose(a, b[, rtol, atol, equal_nan])

Returns a boolean array where two arrays are element-wise equal within a tolerance.

iscomplex(x)

Returns a bool array, where True if input element is complex.

iscomplexobj(x)

Check for a complex type or an array of complex numbers.

isdtype(dtype, kind)

Returns a boolean indicating whether a provided dtype is of a specified kind.

isfinite(x, /)

Test element-wise for finiteness (not infinity and not Not a Number).

isin(element, test_elements[, ...])

Calculates element in test_elements, broadcasting over element only.

isinf(x, /)

Test element-wise for positive or negative infinity.

isnan(x, /)

Test element-wise for NaN and return result as a boolean array.

isneginf(x, /[, out])

Test element-wise for negative infinity, return result as bool array.

isposinf(x, /[, out])

Test element-wise for positive infinity, return result as bool array.

isreal(x)

Returns a bool array, where True if input element is real.

isrealobj(x)

Return True if x is neither a complex type nor an array of complex numbers.

isscalar(element)

Returns True if the type of element is a scalar type.

issubdtype(arg1, arg2)

Returns True if first argument is a typecode lower/equal in type hierarchy.

iterable(y)

Check whether or not an object can be iterated over.

ix_(*args)

Construct an open mesh from multiple sequences.

kaiser(M, beta)

Return the Kaiser window.

kron(a, b)

Kronecker product of two arrays.

lcm(x1, x2)

Returns the lowest common multiple of |x1| and |x2|.

ldexp(x1, x2, /)

Returns x1 * 2**x2, element-wise.

left_shift(x1, x2, /)

Shift the bits of an integer to the left.

less(x1, x2, /)

Return the truth value of (x1 < x2) element-wise.

less_equal(x1, x2, /)

Return the truth value of (x1 <= x2) element-wise.

lexsort(keys[, axis])

Perform an indirect stable sort using a sequence of keys.

linspace(start, stop[, num, endpoint, ...])

Return evenly spaced numbers over a specified interval.

load(*args, **kwargs)

Load arrays or pickled objects from .npy, .npz or pickled files.

log(x, /)

Natural logarithm, element-wise.

log10(x, /)

Return the base 10 logarithm of the input array, element-wise.

log1p(x, /)

Return the natural logarithm of one plus the input array, element-wise.

log2(x, /)

Base-2 logarithm of x.

logaddexp(x1, x2, /)

Logarithm of the sum of exponentiations of the inputs.

logaddexp2(x1, x2, /)

Logarithm of the sum of exponentiations of the inputs in base-2.

logical_and(*args)

Compute the truth value of x1 AND x2 element-wise.

logical_not(*args)

Compute the truth value of NOT x element-wise.

logical_or(*args)

Compute the truth value of x1 OR x2 element-wise.

logical_xor(*args)

Compute the truth value of x1 XOR x2, element-wise.

logspace(start, stop[, num, endpoint, base, ...])

Return numbers spaced evenly on a log scale.

mask_indices(*args, **kwargs)

Return the indices to access (n, n) arrays, given a masking function.

matmul(a, b, *[, precision, ...])

Matrix product of two arrays.

matrix_transpose(x, /)

Transposes the last two dimensions of x.

max(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis.

maximum(x1, x2, /)

Element-wise maximum of array elements.

mean(a[, axis, dtype, out, keepdims, where])

Compute the arithmetic mean along the specified axis.

median(a[, axis, out, overwrite_input, keepdims])

Compute the median along the specified axis.

meshgrid(*xi[, copy, sparse, indexing])

Return a list of coordinate matrices from coordinate vectors.

mgrid

Return dense multi-dimensional "meshgrid".

min(a[, axis, out, keepdims, initial, where])

Return the minimum of an array or minimum along an axis.

minimum(x1, x2, /)

Element-wise minimum of array elements.

mod(x1, x2, /)

Returns the element-wise remainder of division.

modf(x, /[, out])

Return the fractional and integral parts of an array, element-wise.

moveaxis(a, source, destination)

Move axes of an array to new positions.

multiply(x1, x2, /)

Multiply arguments element-wise.

nan_to_num(x[, copy, nan, posinf, neginf])

Replace NaN with zero and infinity with large finite numbers (default behavior) or with the numbers defined by the user using the nan, posinf and/or neginf keywords.

nanargmax(a[, axis, out, keepdims])

Return the indices of the maximum values in the specified axis ignoring NaNs.

nanargmin(a[, axis, out, keepdims])

Return the indices of the minimum values in the specified axis ignoring NaNs.

nancumprod(a[, axis, dtype, out])

Return the cumulative product of array elements over a given axis treating Not a Numbers (NaNs) as one.

nancumsum(a[, axis, dtype, out])

Return the cumulative sum of array elements over a given axis treating Not a Numbers (NaNs) as zero.

nanmax(a[, axis, out, keepdims, initial, where])

Return the maximum of an array or maximum along an axis, ignoring any NaNs.

nanmean(a[, axis, dtype, out, keepdims, where])

Compute the arithmetic mean along the specified axis, ignoring NaNs.

nanmedian(a[, axis, out, overwrite_input, ...])

Compute the median along the specified axis, while ignoring NaNs.

nanmin(a[, axis, out, keepdims, initial, where])

Return minimum of an array or minimum along an axis, ignoring any NaNs.

nanpercentile(a, q[, axis, out, ...])

Compute the qth percentile of the data along the specified axis, while ignoring NaN values.

nanprod(a[, axis, dtype, out, keepdims, ...])

Return the product of array elements over a given axis treating Not a Numbers (NaNs) as ones.

nanquantile(a, q[, axis, out, ...])

Compute the qth quantile of the data along the specified axis, while ignoring NaN values.

nanstd(a[, axis, dtype, out, ddof, ...])

Compute the standard deviation along the specified axis, while ignoring NaNs.

nansum(a[, axis, dtype, out, keepdims, ...])

Return the sum of array elements over a given axis treating Not a Numbers (NaNs) as zero.

nanvar(a[, axis, dtype, out, ddof, ...])

Compute the variance along the specified axis, while ignoring NaNs.

ndarray

alias of Array

ndim(a)

Return the number of dimensions of an array.

negative(x, /)

Numerical negative, element-wise.

nextafter(x1, x2, /)

Return the next floating-point value after x1 towards x2, element-wise.

nonzero(a, *[, size, fill_value])

Return the indices of the elements that are non-zero.

not_equal(x1, x2, /)

Return (x1 != x2) element-wise.

number()

Abstract base class of all numeric scalar types.

object_

Any Python object.

ogrid

Return open multi-dimensional "meshgrid".

ones(shape[, dtype, device])

Return a new array of given shape and type, filled with ones.

ones_like(a[, dtype, shape, device])

Return an array of ones with the same shape and type as a given array.

outer(a, b[, out])

Compute the outer product of two vectors.

packbits(a[, axis, bitorder])

Packs the elements of a binary-valued array into bits in a uint8 array.

pad(array, pad_width[, mode])

Pad an array.

partition(a, kth[, axis])

Return a partitioned copy of an array.

percentile(a, q[, axis, out, ...])

Compute the q-th percentile of the data along the specified axis.

permute_dims(a, /, axes)

Permute the axes (dimensions) of an array.

piecewise(x, condlist, funclist, *args, **kw)

Evaluate a piecewise-defined function.

place(arr, mask, vals, *[, inplace])

Change elements of an array based on conditional and input values.

poly(seq_of_zeros)

Find the coefficients of a polynomial with the given sequence of roots.

polyadd(a1, a2)

Find the sum of two polynomials.

polyder(p[, m])

Return the derivative of the specified order of a polynomial.

polydiv(u, v, *[, trim_leading_zeros])

Returns the quotient and remainder of polynomial division.

polyfit(x, y, deg[, rcond, full, w, cov])

Least squares polynomial fit.

polyint(p[, m, k])

Return an antiderivative (indefinite integral) of a polynomial.

polymul(a1, a2, *[, trim_leading_zeros])

Find the product of two polynomials.

polysub(a1, a2)

Difference (subtraction) of two polynomials.

polyval(p, x, *[, unroll])

Evaluate a polynomial at specific values.

positive(x, /)

Numerical positive, element-wise.

pow(x1, x2, /)

First array elements raised to powers from second array, element-wise.

power(x1, x2, /)

First array elements raised to powers from second array, element-wise.

printoptions(*args, **kwargs)

Context manager for setting print options.

prod(a[, axis, dtype, out, keepdims, ...])

Return the product of array elements over a given axis.

promote_types(a, b)

Returns the type to which a binary operation should cast its arguments.

ptp(a[, axis, out, keepdims])

Range of values (maximum - minimum) along an axis.

put(a, ind, v[, mode, inplace])

Replaces specified elements of an array with given values.
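As with other mutating NumPy routines, a sketch of put in JAX passes inplace=False and keeps the returned array; the original array is not modified.

```python
import jax.numpy as jnp

a = jnp.zeros(5)
# Write 1.0 at flat index 0 and 2.0 at flat index 2 into a new array.
b = jnp.put(a, jnp.array([0, 2]), jnp.array([1.0, 2.0]), inplace=False)
print(b)
```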

quantile(a, q[, axis, out, overwrite_input, ...])

Compute the q-th quantile of the data along the specified axis.

r_

Concatenate slices, scalars and array-like objects along the first axis.

rad2deg(x, /)

Convert angles from radians to degrees.

radians(x, /)

Convert angles from degrees to radians.

ravel(a[, order])

Return a contiguous flattened array.

ravel_multi_index(multi_index, dims[, mode, ...])

Converts a tuple of index arrays into an array of flat indices, applying boundary modes to the multi-index.

real(val, /)

Return the real part of the complex argument.

reciprocal(x, /)

Return the reciprocal of the argument, element-wise.

remainder(x1, x2, /)

Returns the element-wise remainder of division.

repeat(a, repeats[, axis, total_repeat_length])

Repeat each element of an array after themselves
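When repeats is itself a traced array, the output length is data-dependent. A sketch of making repeat usable under jax.jit supplies total_repeat_length statically; repeat_traced is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

@jax.jit
def repeat_traced(x, repeats):
    # `repeats` is traced, so the total output length must be given statically.
    return jnp.repeat(x, repeats, total_repeat_length=5)

out = repeat_traced(jnp.array([1, 2, 3]), jnp.array([2, 0, 3]))
print(out)
```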

reshape(a, newshape[, order])

Gives a new shape to an array without changing its data.

resize(a, new_shape)

Return a new array with the specified shape.

result_type(*args)

Returns the type that results from applying JAX's type promotion rules to the arguments.

right_shift(x1, x2, /)

Shift the bits of an integer to the right.

rint(x, /)

Round elements of the array to the nearest integer.

roll(a, shift[, axis])

Roll array elements along a given axis.

rollaxis(a, axis[, start])

Roll the specified axis backwards, until it lies in a given position.

roots(p, *[, strip_zeros])

Return the roots of a polynomial with coefficients given in p.

rot90(m[, k, axes])

Rotate an array by 90 degrees in the plane specified by axes.

round(a[, decimals, out])

Round an array to the given number of decimals.

round_(a[, decimals, out])

Round an array to the given number of decimals.

s_

A nicer way to build up index tuples for arrays.

save(file, arr[, allow_pickle, fix_imports])

Save an array to a binary file in NumPy .npy format.

savez(file, *args, **kwds)

Save several arrays into a single file in uncompressed .npz format.

searchsorted(a, v[, side, sorter, method])

Find indices where elements should be inserted to maintain order.

select(condlist, choicelist[, default])

Return an array drawn from elements in choicelist, depending on conditions.

set_printoptions([precision, threshold, ...])

Set printing options.

setdiff1d(ar1, ar2[, assume_unique, size, ...])

Find the set difference of two arrays.

setxor1d(ar1, ar2[, assume_unique])

Find the set exclusive-or of two arrays.

shape(a)

Return the shape of an array.

sign(x, /)

Returns an element-wise indication of the sign of a number.

signbit(x, /)

Returns element-wise True where signbit is set (less than zero).

signedinteger()

Abstract base class of all signed integer scalar types.

sin(x, /)

Trigonometric sine, element-wise.

sinc(x, /)

Return the normalized sinc function.

single

alias of float32

sinh(x, /)

Hyperbolic sine, element-wise.

size(a[, axis])

Return the number of elements along a given axis.

sort(a[, axis, kind, order, stable, descending])

Return a sorted copy of an array.

sort_complex(a)

Sort a complex array using the real part first, then the imaginary part.

split(ary, indices_or_sections[, axis])

Split an array into multiple sub-arrays.

sqrt(x, /)

Return the non-negative square-root of an array, element-wise.

square(x, /)

Return the element-wise square of the input.

squeeze(a[, axis])

Remove axes of length one from a.

stack(arrays[, axis, out, dtype])

Join a sequence of arrays along a new axis.

std(a[, axis, dtype, out, ddof, keepdims, where])

Compute the standard deviation along the specified axis.

subtract(x1, x2, /)

Subtract arguments, element-wise.

sum(a[, axis, dtype, out, keepdims, ...])

Sum of array elements over a given axis.

swapaxes(a, axis1, axis2)

Interchange two axes of an array.

take(a, indices[, axis, out, mode, ...])

Take elements from an array along an axis.

take_along_axis(arr, indices, axis[, mode])

Take values from the input array by matching 1d index and data slices.

tan(x, /)

Compute tangent element-wise.

tanh(x, /)

Compute hyperbolic tangent element-wise.

tensordot(a, b[, axes, precision, ...])

Compute tensor dot product along specified axes.

tile(A, reps)

Construct an array by repeating A the number of times given by reps.

trace(a[, offset, axis1, axis2, dtype, out])

Return the sum along diagonals of the array.

transpose(a[, axes])

Returns an array with axes transposed.

tri(N[, M, k, dtype])

An array with ones at and below the given diagonal and zeros elsewhere.

tril(m[, k])

Lower triangle of an array.

tril_indices(n[, k, m])

Return the indices for the lower-triangle of an (n, m) array.

tril_indices_from(arr[, k])

Return the indices for the lower-triangle of arr.

trim_zeros(filt[, trim])

Trim the leading and/or trailing zeros from a 1-D array or sequence.

triu(m[, k])

Upper triangle of an array.

triu_indices(n[, k, m])

Return the indices for the upper-triangle of an (n, m) array.

triu_indices_from(arr[, k])

Return the indices for the upper-triangle of arr.

true_divide(x1, x2, /)

Divide arguments element-wise.

trunc(x)

Return the truncated value of the input, element-wise.

ufunc(func, /, nin, nout, *[, name, nargs, ...])

Functions that operate element-by-element on whole arrays.

uint

alias of uint64

uint16(x)

uint32(x)

uint64(x)

uint8(x)

union1d(ar1, ar2, *[, size, fill_value])

Find the union of two arrays.

unique(ar[, return_index, return_inverse, ...])

Find the unique elements of an array.
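A sketch of unique under jax.jit, using the static size argument and an explicit fill_value to pad the sorted result to a fixed length; unique_padded is a hypothetical name.

```python
import jax
import jax.numpy as jnp

@jax.jit
def unique_padded(x):
    # Three distinct values, padded to a static length of 4 with 0.
    return jnp.unique(x, size=4, fill_value=0)

u = unique_padded(jnp.array([3, 1, 3, 2]))
print(u)
```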

unique_all(x, /)

Return the unique values of an array, along with their first indices, inverse indices, and counts.

unique_counts(x, /)

Return the unique values of an array and their counts.

unique_inverse(x, /)

Return the unique values of an array and the inverse indices.

unique_values(x, /)

Return the unique values of an array.

unpackbits(a[, axis, count, bitorder])

Unpacks elements of a uint8 array into a binary-valued output array.

unravel_index(indices, shape)

Converts a flat index or array of flat indices into a tuple of coordinate arrays.

unsignedinteger()

Abstract base class of all unsigned integer scalar types.

unwrap(p[, discont, axis, period])

Unwrap by taking the complement of large deltas with respect to the period.

vander(x[, N, increasing])

Generate a Vandermonde matrix.

var(a[, axis, dtype, out, ddof, keepdims, where])

Compute the variance along the specified axis.

vdot(a, b, *[, precision, ...])

Return the dot product of two vectors.

vecdot(x1, x2, /, *[, axis, precision, ...])

Perform a conjugate multiplication of two batched vectors. In addition to the standard NumPy arguments, it also supports precision for extra control over matrix-multiplication precision on supported devices.

vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.

vsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays vertically (row-wise).

vstack(tup[, dtype])

Stack arrays in sequence vertically (row wise).

where([acondition, if_true, if_false, size, ...])

Return elements chosen from x or y depending on condition.

zeros(shape[, dtype, device])

Return a new array of given shape and type, filled with zeros.

zeros_like(a[, dtype, shape, device])

Return an array of zeros with the same shape and type as a given array.

jax.numpy.fft

fft(a[, n, axis, norm])

Compute the one-dimensional discrete Fourier Transform.
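A quick sketch of a forward and inverse transform: applying ifft to the output of fft should recover the original signal up to floating-point error. The signal values are arbitrary.

```python
import jax.numpy as jnp

signal = jnp.array([1.0, 2.0, 0.0, -1.0])
spectrum = jnp.fft.fft(signal)            # complex-valued; spectrum[0] is the sum
recovered = jnp.fft.ifft(spectrum).real   # imaginary part is ~0 for real input
print(spectrum, recovered)
```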

fft2(a[, s, axes, norm])

Compute the 2-dimensional discrete Fourier Transform.

fftfreq(n[, d, dtype])

Return the Discrete Fourier Transform sample frequencies.

fftn(a[, s, axes, norm])

Compute the N-dimensional discrete Fourier Transform.

fftshift(x[, axes])

Shift the zero-frequency component to the center of the spectrum.

hfft(a[, n, axis, norm])

Compute the FFT of a signal that has Hermitian symmetry, i.e., a real spectrum.

ifft(a[, n, axis, norm])

Compute the one-dimensional inverse discrete Fourier Transform.

ifft2(a[, s, axes, norm])

Compute the 2-dimensional inverse discrete Fourier Transform.

ifftn(a[, s, axes, norm])

Compute the N-dimensional inverse discrete Fourier Transform.

ifftshift(x[, axes])

The inverse of fftshift.

ihfft(a[, n, axis, norm])

Compute the inverse FFT of a signal that has Hermitian symmetry.

irfft(a[, n, axis, norm])

Computes the inverse of rfft.

irfft2(a[, s, axes, norm])

Computes the inverse of rfft2.

irfftn(a[, s, axes, norm])

Computes the inverse of rfftn.

rfft(a[, n, axis, norm])

Compute the one-dimensional discrete Fourier Transform for real input.

rfft2(a[, s, axes, norm])

Compute the 2-dimensional FFT of a real array.

rfftfreq(n[, d, dtype])

Return the Discrete Fourier Transform sample frequencies (for usage with rfft, irfft).

rfftn(a[, s, axes, norm])

Compute the N-dimensional discrete Fourier Transform for real input.

jax.numpy.linalg

cholesky(a, *[, upper])

Cholesky decomposition.

cond(x[, p])

Compute the condition number of a matrix.

cross(x1, x2, /, *[, axis])

Compute the cross product of two 3-D vectors.

det(a)

Compute the determinant of an array.

diagonal(x, /, *[, offset])

Extract the diagonal of a matrix or stack of matrices.

eig(a)

Compute the eigenvalues and right eigenvectors of a square array.

eigh(a[, UPLO, symmetrize_input])

Return the eigenvalues and eigenvectors of a complex Hermitian (conjugate symmetric) or a real symmetric matrix.

eigvals(a)

Compute the eigenvalues of a general matrix.

eigvalsh(a[, UPLO])

Compute the eigenvalues of a complex Hermitian or real symmetric matrix.

inv(a)

Compute the (multiplicative) inverse of a matrix.

lstsq(a, b[, rcond, numpy_resid])

Return the least-squares solution to a linear matrix equation.

matmul(x1, x2, /)

Perform a matrix multiplication.

matrix_norm(x, /, *[, keepdims, ord])

Computes the matrix norm of a matrix (or a stack of matrices) x.

matrix_power(a, n)

Raise a square matrix to the (integer) power n.

matrix_rank(M[, tol])

Return matrix rank of array using SVD method

matrix_transpose(x, /)

Transposes a matrix (or a stack of matrices) x.

multi_dot(arrays, *[, precision])

Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.

norm(x[, ord, axis, keepdims])

Matrix or vector norm.

outer(x1, x2, /)

Compute the outer product of two 1-D arrays.

pinv(a[, rcond, hermitian])

Compute the (Moore-Penrose) pseudo-inverse of a matrix.

qr(a[, mode])

Compute the QR factorization of a matrix.

slogdet(a, *[, method])

Compute the sign and (natural) logarithm of the determinant of an array.

solve(a, b)

Solve a linear matrix equation, or system of linear scalar equations.
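A minimal sketch of solving a 2x2 system a @ x = b with jax.numpy.linalg.solve; the matrix values are arbitrary and chosen so the exact solution is [2, 3].

```python
import jax.numpy as jnp

a = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([9.0, 8.0])

x = jnp.linalg.solve(a, b)   # solves a @ x = b
residual = a @ x - b         # should be close to zero
print(x, residual)
```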

svd(a[, full_matrices, compute_uv, ...])

Singular Value Decomposition.

svdvals(x, /)

Compute the singular values of a matrix.

tensordot(x1, x2, /, *[, axes])

Compute the tensor dot product of two N-dimensional arrays.

tensorinv(a[, ind])

Compute the 'inverse' of an N-dimensional array.

tensorsolve(a, b[, axes])

Solve the tensor equation a x = b for x.

vector_norm(x, /, *[, axis, keepdims, ord])

Computes the vector norm of a vector (or batch of vectors) x.

vecdot(x1, x2, /, *[, axis])

Compute the (batched) vector conjugate dot product of two arrays.

JAX Array

The JAX Array (along with its alias, jax.numpy.ndarray) is the core array object in JAX: you can think of it as JAX’s equivalent of a numpy.ndarray. Like numpy.ndarray, most users will not need to instantiate Array objects manually, but rather will create them via jax.numpy functions like array(), arange(), linspace(), and others listed above.
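A sketch of the usual ways to obtain Array objects, assuming default settings: create them through jax.numpy functions, convert existing NumPy data with asarray(), and convert back to a host NumPy array with numpy.asarray().

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(3)               # created via jax.numpy, not by calling Array(...)
y = jnp.asarray(np.ones(3))     # convert existing NumPy data
is_array = isinstance(x, jax.Array) and isinstance(y, jnp.ndarray)
back = np.asarray(x)            # copy the data back into a NumPy array
print(type(x), is_array, back)
```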

Copying and Serialization

JAX Array objects are designed to work seamlessly with Python standard library tools where appropriate.

With the built-in copy module, when copy.copy() or copy.deepcopy() encounters an Array, it is equivalent to calling the copy() method, which creates a copy of the buffer on the same device as the original array. This works correctly within traced/JIT-compiled code, though copy operations may be elided by the compiler in this context.
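A sketch of the copy behavior described above, outside of jit. Array.devices() is used here to check that the deep copy stays on the same device as the original.

```python
import copy

import jax.numpy as jnp

x = jnp.arange(4.0)
x_shallow = copy.copy(x)       # equivalent to x.copy()
x_deep = copy.deepcopy(x)      # also equivalent to x.copy()
same_device = x_deep.devices() == x.devices()
print(x_shallow, x_deep, same_device)
```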

When the built-in pickle module encounters an Array, it will be serialized via a compact bit representation in a similar manner to pickled numpy.ndarray objects. When unpickled, the result will be a new Array object on the default device. This is because in general, pickling and unpickling may take place in different runtime environments, and there is no general way to map the device IDs of one runtime to the device IDs of another. If pickle is used in traced/JIT-compiled code, it will result in a ConcretizationTypeError.
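A sketch of a pickle round trip within a single process. Across processes or machines, the restored array would be placed on that runtime's default device rather than on the original device.

```python
import pickle

import jax
import jax.numpy as jnp

x = jnp.arange(4.0)
payload = pickle.dumps(x)          # compact byte representation of the data
restored = pickle.loads(payload)   # a new Array on the default device
print(type(restored), restored)
```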