jax.scipy module#

jax.scipy.cluster#

vq(obs, code_book[, check_finite])

Assign codes from a code book to a set of observations.

jax.scipy.fft#

dct(x[, type, n, axis, norm])

Computes the discrete cosine transform of the input

dctn(x[, type, s, axes, norm])

Computes the multidimensional discrete cosine transform of the input

idct(x[, type, n, axis, norm])

Computes the inverse discrete cosine transform of the input

idctn(x[, type, s, axes, norm])

Computes the multidimensional inverse discrete cosine transform of the input

jax.scipy.integrate#

trapezoid(y[, x, dx, axis])

Integrate along the given axis using the composite trapezoidal rule.

jax.scipy.interpolate#

RegularGridInterpolator(points, values[, ...])

Interpolate points on a regular rectangular grid.

jax.scipy.linalg#

block_diag(*arrs)

Create a block diagonal matrix from input arrays.

cho_factor(a[, lower, overwrite_a, check_finite])

Factorization for Cholesky-based linear solves

cho_solve(c_and_lower, b[, overwrite_b, ...])

Solve a linear system using a Cholesky factorization

cholesky(a[, lower, overwrite_a, check_finite])

Compute the Cholesky decomposition of a matrix.

det(a[, overwrite_a, check_finite])

Compute the determinant of a matrix

eigh()

Compute eigenvalues and eigenvectors for a Hermitian matrix

eigh_tridiagonal(d, e, *[, eigvals_only, ...])

Solve the eigenvalue problem for a symmetric real tridiagonal matrix

expm(A, *[, upper_triangular, max_squarings])

Compute the matrix exponential

expm_frechet()

Compute the Frechet derivative of the matrix exponential.

funm(A, func[, disp])

Evaluate a matrix-valued function

hessenberg()

Compute the Hessenberg form of the matrix

hilbert(n)

Create a Hilbert matrix of order n.

inv(a[, overwrite_a, check_finite])

Return the inverse of a square matrix

lu()

Compute the LU decomposition

lu_factor(a[, overwrite_a, check_finite])

Factorization for LU-based linear solves

lu_solve(lu_and_piv, b[, trans, ...])

Solve a linear system using an LU factorization

polar(a[, side, method, eps, max_iterations])

Computes the polar decomposition.

qr()

Compute the QR decomposition of an array

rsf2csf(T, Z[, check_finite])

Convert real Schur form to complex Schur form.

schur(a[, output])

Compute the Schur decomposition

solve(a, b[, lower, overwrite_a, ...])

Solve a linear system of equations

solve_triangular(a, b[, trans, lower, ...])

Solve a triangular linear system of equations

sqrtm(A[, blocksize])

Compute the matrix square root

svd()

Compute the singular value decomposition.

toeplitz(c[, r])

Construct a Toeplitz matrix

jax.scipy.ndimage#

map_coordinates(input, coordinates, order[, ...])

Map the input array to new coordinates using interpolation.

jax.scipy.optimize#

minimize(fun, x0[, args, tol, options])

Minimization of a scalar function of one or more variables.

OptimizeResults(x, success, status, fun, ...)

Object holding optimization results.

jax.scipy.signal#

fftconvolve(in1, in2[, mode, axes])

Convolve two N-dimensional arrays using Fast Fourier Transform (FFT).

convolve(in1, in2[, mode, method, precision])

Convolution of two N-dimensional arrays.

convolve2d(in1, in2[, mode, boundary, ...])

Convolution of two 2-dimensional arrays.

correlate(in1, in2[, mode, method, precision])

Cross-correlation of two N-dimensional arrays.

correlate2d(in1, in2[, mode, boundary, ...])

Cross-correlation of two 2-dimensional arrays.

csd(x, y[, fs, window, nperseg, noverlap, ...])

Estimate cross power spectral density (CSD) using Welch's method.

detrend(data[, axis, type, bp, overwrite_data])

Remove linear or piecewise linear trends from data.

istft(Zxx[, fs, window, nperseg, noverlap, ...])

Perform the inverse short-time Fourier transform (ISTFT).

stft(x[, fs, window, nperseg, noverlap, ...])

Compute the short-time Fourier transform (STFT).

welch(x[, fs, window, nperseg, noverlap, ...])

Estimate power spectral density (PSD) using Welch's method.

jax.scipy.spatial.transform#

Rotation(quat)

Rotation in 3 dimensions.

Slerp(times, timedelta, rotations, rotvecs)

Spherical Linear Interpolation of Rotations.

jax.scipy.sparse.linalg#

bicgstab(A, b[, x0, tol, atol, maxiter, M])

Use Bi-Conjugate Gradient Stable iteration to solve Ax = b.

cg(A, b[, x0, tol, atol, maxiter, M])

Use Conjugate Gradient iteration to solve Ax = b.

gmres(A, b[, x0, tol, atol, restart, ...])

GMRES solves the linear system A x = b for x, given A and b.

jax.scipy.special#

bernoulli(n)

Generate the first N Bernoulli numbers.

beta(x, y)

The beta function

betainc(a, b, x)

The regularized incomplete beta function.

betaln(a, b)

Natural log of the absolute value of the beta function

digamma(x)

The digamma function

entr(x)

The entropy function

erf(x)

The error function

erfc(x)

The complement of the error function

erfinv(x)

The inverse of the error function

exp1(x)

Exponential integral function E1(x).

expi

Exponential integral function Ei(x).

expit(x)

The logistic sigmoid (expit) function

expn

Generalized exponential integral function.

factorial(n[, exact])

Factorial function

gamma(x)

The gamma function.

gammainc(a, x)

The regularized lower incomplete gamma function.

gammaincc(a, x)

The regularized upper incomplete gamma function.

gammaln(x)

Natural log of the absolute value of the gamma function.

gammasgn(x)

Sign of the gamma function.

hyp1f1

The 1F1 hypergeometric function.

i0(x)

Modified Bessel function of zeroth order.

i0e(x)

Exponentially scaled modified Bessel function of zeroth order.

i1(x)

Modified Bessel function of first order.

i1e(x)

Exponentially scaled modified Bessel function of first order.

log_ndtr

Log of the normal distribution function.

logit

The logit function

logsumexp()

Log-sum-exp reduction.

lpmn(m, n, z)

The associated Legendre functions (ALFs) of the first kind.

lpmn_values(m, n, z, is_normalized)

The associated Legendre functions (ALFs) of the first kind.

multigammaln(a, d)

The natural log of the multivariate gamma function.

ndtr(x)

Normal distribution function.

ndtri(p)

The inverse of the CDF of the Normal distribution function.

poch

The Pochhammer symbol.

polygamma(n, x)

The polygamma function.

spence(x)

Spence's function, also known as the dilogarithm for real values.

sph_harm(m, n, theta, phi[, n_max])

Computes the spherical harmonics.

xlog1py

Compute x*log(1 + y), returning 0 for x=0.

xlogy

Compute x*log(y), returning 0 for x=0.

zeta

The Hurwitz zeta function.

kl_div(p, q)

The Kullback-Leibler divergence.

rel_entr(p, q)

The relative entropy function.

jax.scipy.stats#

mode(a[, axis, nan_policy, keepdims])

Compute the mode (most common value) along an axis of an array.

rankdata(a[, method, axis, nan_policy])

Compute the rank of data along an array axis.

sem(a[, axis, ddof, nan_policy, keepdims])

Compute the standard error of the mean.

jax.scipy.stats.bernoulli#

logpmf(k, p[, loc])

Bernoulli log probability mass function.

pmf(k, p[, loc])

Bernoulli probability mass function.

cdf(k, p)

Bernoulli cumulative distribution function.

ppf(q, p)

Bernoulli percent point function.

jax.scipy.stats.beta#

logpdf(x, a, b[, loc, scale])

Beta log probability distribution function.

pdf(x, a, b[, loc, scale])

Beta probability distribution function.

cdf(x, a, b[, loc, scale])

Beta cumulative distribution function

logcdf(x, a, b[, loc, scale])

Beta log cumulative distribution function.

sf(x, a, b[, loc, scale])

Beta distribution survival function.

logsf(x, a, b[, loc, scale])

Beta distribution log survival function.

jax.scipy.stats.betabinom#

logpmf(k, n, a, b[, loc])

Beta-binomial log probability mass function.

pmf(k, n, a, b[, loc])

Beta-binomial probability mass function.

jax.scipy.stats.binom#

logpmf(k, n, p[, loc])

Binomial log probability mass function.

pmf(k, n, p[, loc])

Binomial probability mass function.

jax.scipy.stats.cauchy#

logpdf(x[, loc, scale])

Cauchy log probability distribution function.

pdf(x[, loc, scale])

Cauchy probability distribution function.

cdf(x[, loc, scale])

Cauchy cumulative distribution function.

logcdf(x[, loc, scale])

Cauchy log cumulative distribution function.

sf(x[, loc, scale])

Cauchy distribution survival function.

logsf(x[, loc, scale])

Cauchy distribution log survival function.

isf(q[, loc, scale])

Cauchy distribution inverse survival function.

ppf(q[, loc, scale])

Cauchy distribution percent point function.

jax.scipy.stats.chi2#

logpdf(x, df[, loc, scale])

Chi-square log probability distribution function.

pdf(x, df[, loc, scale])

Chi-square probability distribution function.

cdf(x, df[, loc, scale])

Chi-square cumulative distribution function.

logcdf(x, df[, loc, scale])

Chi-square log cumulative distribution function.

sf(x, df[, loc, scale])

Chi-square survival function.

logsf(x, df[, loc, scale])

Chi-square log survival function.

jax.scipy.stats.dirichlet#

logpdf(x, alpha)

Dirichlet log probability distribution function.

pdf(x, alpha)

Dirichlet probability distribution function.

jax.scipy.stats.expon#

logpdf(x[, loc, scale])

Exponential log probability distribution function.

pdf(x[, loc, scale])

Exponential probability distribution function.

jax.scipy.stats.gamma#

logpdf(x, a[, loc, scale])

Gamma log probability distribution function.

pdf(x, a[, loc, scale])

Gamma probability distribution function.

cdf(x, a[, loc, scale])

Gamma cumulative distribution function.

logcdf(x, a[, loc, scale])

Gamma log cumulative distribution function.

sf(x, a[, loc, scale])

Gamma survival function.

logsf(x, a[, loc, scale])

Gamma log survival function.

jax.scipy.stats.gennorm#

cdf(x, beta)

Generalized normal cumulative distribution function.

logpdf(x, beta)

Generalized normal log probability distribution function.

pdf(x, beta)

Generalized normal probability distribution function.

jax.scipy.stats.geom#

logpmf(k, p[, loc])

Geometric log probability mass function.

pmf(k, p[, loc])

Geometric probability mass function.

jax.scipy.stats.laplace#

cdf(x[, loc, scale])

Laplace cumulative distribution function.

logpdf(x[, loc, scale])

Laplace log probability distribution function.

pdf(x[, loc, scale])

Laplace probability distribution function.

jax.scipy.stats.logistic#

cdf(x[, loc, scale])

Logistic cumulative distribution function.

isf(x[, loc, scale])

Logistic distribution inverse survival function.

logpdf(x[, loc, scale])

Logistic log probability distribution function.

pdf(x[, loc, scale])

Logistic probability distribution function.

ppf(x[, loc, scale])

Logistic distribution percent point function.

sf(x[, loc, scale])

Logistic distribution survival function.

jax.scipy.stats.multinomial#

logpmf(x, n, p)

Multinomial log probability mass function.

pmf(x, n, p)

Multinomial probability mass function.

jax.scipy.stats.multivariate_normal#

logpdf(x, mean, cov[, allow_singular])

Multivariate normal log probability distribution function.

pdf(x, mean, cov)

Multivariate normal probability distribution function.

jax.scipy.stats.nbinom#

logpmf(k, n, p[, loc])

Negative-binomial log probability mass function.

pmf(k, n, p[, loc])

Negative-binomial probability mass function.

jax.scipy.stats.norm#

logpdf(x[, loc, scale])

Normal log probability distribution function.

pdf(x[, loc, scale])

Normal probability distribution function.

cdf(x[, loc, scale])

Normal cumulative distribution function.

logcdf(x[, loc, scale])

Normal log cumulative distribution function.

ppf(q[, loc, scale])

Normal distribution percent point function.

sf(x[, loc, scale])

Normal distribution survival function.

logsf(x[, loc, scale])

Normal distribution log survival function.

isf(q[, loc, scale])

Normal distribution inverse survival function.

jax.scipy.stats.pareto#

logpdf(x, b[, loc, scale])

Pareto log probability distribution function.

pdf(x, b[, loc, scale])

Pareto probability distribution function.

jax.scipy.stats.poisson#

logpmf(k, mu[, loc])

Poisson log probability mass function.

pmf(k, mu[, loc])

Poisson probability mass function.

cdf(k, mu[, loc])

Poisson cumulative distribution function.

jax.scipy.stats.t#

logpdf(x, df[, loc, scale])

Student's T log probability distribution function.

pdf(x, df[, loc, scale])

Student's T probability distribution function.

jax.scipy.stats.truncnorm#

cdf(x, a, b[, loc, scale])

Truncated normal cumulative distribution function.

logcdf(x, a, b[, loc, scale])

Truncated normal log cumulative distribution function.

logpdf(x, a, b[, loc, scale])

Truncated normal log probability distribution function.

logsf(x, a, b[, loc, scale])

Truncated normal distribution log survival function.

pdf(x, a, b[, loc, scale])

Truncated normal probability distribution function.

sf(x, a, b[, loc, scale])

Truncated normal distribution survival function.

jax.scipy.stats.uniform#

logpdf(x[, loc, scale])

Uniform log probability distribution function.

pdf(x[, loc, scale])

Uniform probability distribution function.

cdf(x[, loc, scale])

Uniform cumulative distribution function.

ppf(q[, loc, scale])

Uniform distribution percent point function.

jax.scipy.stats.gaussian_kde#

gaussian_kde(dataset[, bw_method, weights])

Gaussian Kernel Density Estimator

gaussian_kde.evaluate(points)

Evaluate the Gaussian KDE on the given points.

gaussian_kde.integrate_gaussian(mean, cov)

Integrate the distribution weighted by a Gaussian.

gaussian_kde.integrate_box_1d(low, high)

Integrate the distribution over the given limits.

gaussian_kde.integrate_kde(other)

Integrate the product of two Gaussian KDE distributions.

gaussian_kde.resample(key[, shape])

Randomly sample a dataset from the estimated pdf

gaussian_kde.pdf(x)

Probability density function

gaussian_kde.logpdf(x)

Log probability density function

jax.scipy.stats.vonmises#

logpdf(x, kappa)

von Mises log probability distribution function.

pdf(x, kappa)

von Mises probability distribution function.

jax.scipy.stats.wrapcauchy#

logpdf(x, c)

Wrapped Cauchy log probability distribution function.

pdf(x, c)

Wrapped Cauchy probability distribution function.