jax.scipy package#

jax.scipy.fft#

dct(x[, type, n, axis, norm])

Return the Discrete Cosine Transform of arbitrary type sequence x.

dctn(x[, type, s, axes, norm])

Return multidimensional Discrete Cosine Transform along the specified axes.

jax.scipy.linalg#

block_diag(*arrs)

Create a block diagonal matrix from provided arrays.

cho_factor(a[, lower, overwrite_a, check_finite])

Compute the Cholesky decomposition of a matrix, for use in cho_solve.

cho_solve(c_and_lower, b[, overwrite_b, ...])

Solve the linear equations A x = b, given the Cholesky factorization of A.

cholesky(a[, lower, overwrite_a, check_finite])

Compute the Cholesky decomposition of a matrix.

det(a[, overwrite_a, check_finite])

Compute the determinant of a matrix.

eigh(a[, b, lower, eigvals_only, ...])

Solve a standard or generalized eigenvalue problem for a complex Hermitian or real symmetric matrix.

eigh_tridiagonal(d, e, *[, eigvals_only, ...])

Solve eigenvalue problem for a real symmetric tridiagonal matrix.

expm(A, *[, upper_triangular, max_squarings])

Compute the matrix exponential of an array.

expm_frechet(A, E, *[, method, compute_expm])

Fréchet derivative of the matrix exponential of A in the direction E.

funm(A, func[, disp])

Evaluate a matrix function specified by a callable.

inv(a[, overwrite_a, check_finite])

Compute the inverse of a matrix.

lu(a[, permute_l, overwrite_a, check_finite])

Compute pivoted LU decomposition of a matrix.

lu_factor(a[, overwrite_a, check_finite])

Compute pivoted LU decomposition of a matrix.

lu_solve(lu_and_piv, b[, trans, ...])

Solve an equation system, a x = b, given the LU factorization of a.

polar(a[, side, method, eps, max_iterations])

Computes the polar decomposition.

polar_unitary(a, *[, method, eps, ...])

Computes the unitary factor u in the polar decomposition a = u p (or a = p u).

qr(a[, overwrite_a, lwork, mode, pivoting, ...])

Compute QR decomposition of a matrix.

rsf2csf(T, Z[, check_finite])

Convert real Schur form to complex Schur form.

schur(a[, output])

Compute Schur decomposition of a matrix.

sqrtm(A[, blocksize])

Matrix square root.

solve(a, b[, sym_pos, lower, overwrite_a, ...])

Solve the linear equation set a x = b for the unknown x.

solve_triangular(a, b[, trans, lower, ...])

Solve the equation a x = b for x, assuming a is a triangular matrix.

sqrtm(A[, blocksize])

Matrix square root.

svd(a[, full_matrices, compute_uv, ...])

Singular Value Decomposition.

tril(m[, k])

Make a copy of a matrix with elements above the kth diagonal zeroed.

triu(m[, k])

Make a copy of a matrix with elements below the kth diagonal zeroed.

jax.scipy.ndimage#

map_coordinates(input, coordinates, order[, ...])

Map the input array to new coordinates by interpolation.

jax.scipy.optimize#

minimize(fun, x0[, args, tol, options])

Minimization of scalar function of one or more variables.

OptimizeResults(x, success, status, fun, ...)

Object holding optimization results.

jax.scipy.signal#

convolve(in1, in2[, mode, method, precision])

Convolve two N-dimensional arrays.

convolve2d(in1, in2[, mode, boundary, ...])

Convolve two 2-dimensional arrays.

correlate(in1, in2[, mode, method, precision])

Cross-correlate two N-dimensional arrays.

correlate2d(in1, in2[, mode, boundary, ...])

Cross-correlate two 2-dimensional arrays.

csd(x, y[, fs, window, nperseg, noverlap, ...])

Estimate the cross power spectral density, Pxy, using Welch's method.

istft(Zxx[, fs, window, nperseg, noverlap, ...])

Perform the inverse Short Time Fourier transform (iSTFT).

stft(x[, fs, window, nperseg, noverlap, ...])

Compute the Short Time Fourier Transform (STFT).

welch(x[, fs, window, nperseg, noverlap, ...])

Estimate power spectral density using Welch's method.

jax.scipy.sparse.linalg#

bicgstab(A, b[, x0, tol, atol, maxiter, M])

Use Bi-Conjugate Gradient Stable iteration to solve Ax = b.

cg(A, b[, x0, tol, atol, maxiter, M])

Use Conjugate Gradient iteration to solve Ax = b.

gmres(A, b[, x0, tol, atol, restart, ...])

GMRES solves the linear system A x = b for x, given A and b.

jax.scipy.special#

betainc(a, b, x)

Incomplete beta function.

digamma(x)

The digamma function.

entr(x)

Elementwise function for computing entropy.

erf(x)

Returns the error function of complex argument.

erfc(x)

Complementary error function, 1 - erf(x).

erfinv(x)

Inverse of the error function.

exp1(x[, module])

Exponential integral E1.

expi

Exponential integral Ei.

expit(x)

Expit (a.k.a. logistic sigmoid) ufunc for ndarrays.

expn

Generalized exponential integral En.

gammainc(a, x)

Regularized lower incomplete gamma function.

gammaincc(a, x)

Regularized upper incomplete gamma function.

gammaln(x)

Logarithm of the absolute value of the gamma function.

i0(x)

Modified Bessel function of order 0.

i0e(x)

Exponentially scaled modified Bessel function of order 0.

i1(x)

Modified Bessel function of order 1.

i1e(x)

Exponentially scaled modified Bessel function of order 1.

log_ndtr

Logarithm of the normal (Gaussian) cumulative distribution function.

logit

Logit ufunc for ndarrays.

logsumexp(a[, axis, b, keepdims, return_sign])

Compute the log of the sum of exponentials of input elements.

lpmn(m, n, z)

The associated Legendre functions (ALFs) of the first kind.

lpmn_values(m, n, z, is_normalized)

The associated Legendre functions (ALFs) of the first kind.

multigammaln(a, d)

Returns the log of the multivariate gamma function, also sometimes called the generalized gamma.

ndtr(x)

Normal (Gaussian) cumulative distribution function.

ndtri(p)

The inverse of the CDF of the Normal distribution function.

polygamma(n, x)

Polygamma functions.

sph_harm(m, n, theta, phi[, n_max])

Computes the spherical harmonics.

xlog1py(x, y)

Compute x*log1p(y) so that the result is 0 if x = 0.

xlogy(x, y)

Compute x*log(y) so that the result is 0 if x = 0.

zeta(x[, q])

Riemann or Hurwitz zeta function.

jax.scipy.stats#

jax.scipy.stats.bernoulli#

logpmf(k, p[, loc])

Log of the probability mass function at k of the given RV.

pmf(k, p[, loc])

Probability mass function at k of the given RV.

jax.scipy.stats.beta#

logpdf(x, a, b[, loc, scale])

Log of the probability density function at x of the given RV.

pdf(x, a, b[, loc, scale])

Probability density function at x of the given RV.

jax.scipy.stats.betabinom#

logpmf(k, n, a, b[, loc])

Log of the probability mass function at k of the given RV.

pmf(k, n, a, b[, loc])

Probability mass function at k of the given RV.

jax.scipy.stats.cauchy#

logpdf(x[, loc, scale])

Log of the probability density function at x of the given RV.

pdf(x[, loc, scale])

Probability density function at x of the given RV.

jax.scipy.stats.chi2#

logpdf(x, df[, loc, scale])

Log of the probability density function at x of the given RV.

pdf(x, df[, loc, scale])

Probability density function at x of the given RV.

jax.scipy.stats.dirichlet#

logpdf(x, alpha)

Log of the Dirichlet probability density function.

pdf(x, alpha)

The Dirichlet probability density function.

jax.scipy.stats.expon#

logpdf(x[, loc, scale])

Log of the probability density function at x of the given RV.

pdf(x[, loc, scale])

Probability density function at x of the given RV.

jax.scipy.stats.gamma#

logpdf(x, a[, loc, scale])

Log of the probability density function at x of the given RV.

pdf(x, a[, loc, scale])

Probability density function at x of the given RV.

jax.scipy.stats.gennorm#

cdf(x, p)

Cumulative distribution function of the given RV.

logpdf(x, p)

Log of the probability density function at x of the given RV.

pdf(x, p)

Probability density function at x of the given RV.

jax.scipy.stats.geom#

logpmf(k, p[, loc])

Log of the probability mass function at k of the given RV.

pmf(k, p[, loc])

Probability mass function at k of the given RV.

jax.scipy.stats.laplace#

cdf(x[, loc, scale])

Cumulative distribution function of the given RV.

logpdf(x[, loc, scale])

Log of the probability density function at x of the given RV.

pdf(x[, loc, scale])

Probability density function at x of the given RV.

jax.scipy.stats.logistic#

cdf(x)

Cumulative distribution function of the given RV.

isf(x)

Inverse survival function (inverse of sf) at x of the given RV.

logpdf(x)

Log of the probability density function at x of the given RV.

pdf(x)

Probability density function at x of the given RV.

ppf(x)

Percent point function (inverse of cdf) at x of the given RV.

sf(x)

Survival function (1 - cdf) at x of the given RV.

jax.scipy.stats.multivariate_normal#

logpdf(x, mean, cov[, allow_singular])

Log of the multivariate normal probability density function.

pdf(x, mean, cov)

Multivariate normal probability density function.

jax.scipy.stats.norm#

cdf(x[, loc, scale])

Cumulative distribution function of the given RV.

logcdf(x[, loc, scale])

Log of the cumulative distribution function at x of the given RV.

logpdf(x[, loc, scale])

Log of the probability density function at x of the given RV.

pdf(x[, loc, scale])

Probability density function at x of the given RV.

ppf(q[, loc, scale])

Percent point function (inverse of cdf) at q of the given RV.

jax.scipy.stats.pareto#

logpdf(x, b[, loc, scale])

Log of the probability density function at x of the given RV.

pdf(x, b[, loc, scale])

Probability density function at x of the given RV.

jax.scipy.stats.poisson#

logpmf(k, mu[, loc])

Log of the probability mass function at k of the given RV.

pmf(k, mu[, loc])

Probability mass function at k of the given RV.

jax.scipy.stats.t#

logpdf(x, df[, loc, scale])

Log of the probability density function at x of the given RV.

pdf(x, df[, loc, scale])

Probability density function at x of the given RV.

jax.scipy.stats.uniform#

logpdf(x[, loc, scale])

Log of the probability density function at x of the given RV.

pdf(x[, loc, scale])

Probability density function at x of the given RV.

jax.scipy.stats.gaussian_kde#

gaussian_kde(dataset[, bw_method, weights])

Representation of a kernel-density estimate using Gaussian kernels.

gaussian_kde.evaluate(points)

Evaluate the estimated pdf on a set of points.

gaussian_kde.integrate_gaussian(mean, cov)

Multiply the estimated density by a multivariate Gaussian and integrate over the whole space.

gaussian_kde.integrate_box_1d(low, high)

Computes the integral of a 1D pdf between two bounds.

gaussian_kde.integrate_kde(other)

Computes the integral of the product of this kernel density estimate with another.

gaussian_kde.resample(key[, shape])

Randomly sample a dataset from the estimated pdf.

gaussian_kde.pdf(x)

Evaluate the estimated pdf on a provided set of points.

gaussian_kde.logpdf(x)

Evaluate the log of the estimated pdf on a provided set of points.