jax.scipy package

jax.scipy.linalg

block_diag(*arrs) Create a block diagonal matrix from provided arrays.
cho_factor(a[, lower, overwrite_a, check_finite]) Compute the Cholesky decomposition of a matrix, to use in cho_solve.
cho_solve(c_and_lower, b[, overwrite_b, …]) Solve the linear equations A x = b, given the Cholesky factorization of A.
cholesky(a[, lower, overwrite_a, check_finite]) Compute the Cholesky decomposition of a matrix.
det(a[, overwrite_a, check_finite]) Compute the determinant of a matrix.
eigh(a[, b, lower, eigvals_only, …]) Solve an ordinary or generalized eigenvalue problem for a complex Hermitian or real symmetric matrix.
expm(A, *[, upper_triangular]) Compute the matrix exponential using Padé approximation.
inv(a[, overwrite_a, check_finite]) Compute the inverse of a matrix.
lu(a[, permute_l, overwrite_a, check_finite]) Compute pivoted LU decomposition of a matrix.
lu_factor(a[, overwrite_a, check_finite]) Compute pivoted LU decomposition of a matrix.
lu_solve(lu_and_piv, b[, trans, …]) Solve an equation system, a x = b, given the LU factorization of a.
qr(a[, overwrite_a, lwork, mode, pivoting, …]) Compute QR decomposition of a matrix.
solve(a, b[, sym_pos, lower, overwrite_a, …]) Solve the linear equation set a * x = b for the unknown x.
solve_triangular(a, b[, trans, lower, …]) Solve the equation a x = b for x, assuming a is a triangular matrix.
svd(a[, full_matrices, compute_uv, …]) Singular Value Decomposition.
tril(m[, k]) Make a copy of a matrix with elements above the k-th diagonal zeroed.
triu(m[, k]) Make a copy of a matrix with elements below the k-th diagonal zeroed.
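
A minimal sketch of how the factor-then-solve pairs relate to the general solver, assuming a recent JAX release with the default float32 precision. The 3-by-3 symmetric positive-definite matrix and right-hand side are illustrative values. The sketch passes only positional a and b to solve (sym_pos is left out because newer releases no longer accept it) and uses a column vector for b so every routine sees a 2-D right-hand side. All four routes should recover the same x; block_diag and expm are checked at the end.

```python
import jax.numpy as jnp
from jax.scipy.linalg import (block_diag, cho_factor, cho_solve, cholesky,
                              expm, lu_factor, lu_solve, solve,
                              solve_triangular)

# Illustrative symmetric positive-definite matrix and a column right-hand side.
A = jnp.array([[4.0, 2.0, 0.0],
               [2.0, 5.0, 1.0],
               [0.0, 1.0, 3.0]])
b = jnp.array([[1.0], [2.0], [3.0]])

# General dense solve.
x_solve = solve(A, b)

# Cholesky route: cho_factor returns (c, lower), reusable for many right-hand sides.
x_cho = cho_solve(cho_factor(A), b)

# LU route: pivoted factorization (lu, piv), then solve.
x_lu = lu_solve(lu_factor(A), b)

# Triangular route: A = L @ L.T, so solve L y = b, then L.T x = y.
L = cholesky(A, lower=True)
y = solve_triangular(L, b, lower=True)
x_tri = solve_triangular(L.T, y, lower=False)

# block_diag places its arguments along the diagonal of a larger matrix.
B = block_diag(jnp.eye(2), jnp.array([[7.0]]))

# The matrix exponential of the zero matrix is the identity.
E = expm(jnp.zeros((3, 3)))

print(x_solve.ravel())
```

Factoring once and calling cho_solve or lu_solve repeatedly is the usual reason to prefer the paired routines over calling solve for each new b.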

jax.scipy.ndimage

map_coordinates(input, coordinates, order[, …]) Map the input array to new coordinates by interpolation.
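
A short sketch of map_coordinates, assuming a recent JAX release. JAX's implementation supports only order 0 (nearest) and order 1 (linear), so this uses order=1. The coordinates argument holds one array per input dimension; here the input is an illustrative 3-by-4 plane, image[i, j] = 4*i + j, which linear interpolation reproduces exactly at points inside the grid.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

# image[i, j] = 4 * i + j, an affine function of the indices.
image = jnp.arange(12.0).reshape(3, 4)

# One coordinate array per input dimension: sample at (row, col) pairs.
rows = jnp.array([0.5, 1.0, 1.5])
cols = jnp.array([0.5, 2.5, 1.25])

# Linear interpolation; all sample points lie inside the grid.
values = map_coordinates(image, [rows, cols], order=1)
print(values)  # approximately 2.5, 6.5, 7.25
```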

jax.scipy.special

betainc(a, b, x) Incomplete beta function.
digamma(x) The digamma function.
entr(x) Elementwise function for computing entropy.
erf(x) Returns the error function of complex argument.
erfc(x) Complementary error function, 1 - erf(x).
erfinv(x) Inverse of the error function erf.
expit Expit (a.k.a. logistic sigmoid) ufunc for ndarrays.
gammainc(a, x) Regularized lower incomplete gamma function.
gammaincc(a, x) Regularized upper incomplete gamma function.
gammaln(x) Logarithm of the absolute value of the Gamma function.
i0e(x) Exponentially scaled modified Bessel function of order 0.
i1e(x) Exponentially scaled modified Bessel function of order 1.
log_ndtr Logarithm of the Normal distribution function (the log of ndtr).
logit Logit ufunc for ndarrays.
logsumexp(a[, axis, b, keepdims, return_sign]) Compute the log of the sum of exponentials of input elements.
multigammaln(a, d) Returns the log of multivariate gamma, also sometimes called the generalized gamma.
ndtr(x) Normal distribution function.
ndtri(p) The inverse of the CDF of the Normal distribution function.
xlog1py(x, y) Compute x*log1p(y) so that the result is 0 if x = 0.
xlogy(x, y) Compute x*log(y) so that the result is 0 if x = 0.
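
A sketch of why several of these special functions exist alongside naive formulas, assuming a recent JAX release with the default float32 precision; the input values are illustrative. logsumexp avoids overflow, xlogy gives 0 where 0 * log(0) would be nan, expit and logit are inverses on (0, 1), and gammaln(n + 1) equals log(n!).

```python
import jax.numpy as jnp
from jax.scipy.special import expit, gammaln, logit, logsumexp, xlogy

# exp(1000.0) overflows to inf in float32; logsumexp shifts by the max first.
a = jnp.array([1000.0, 1000.0])
naive = jnp.log(jnp.sum(jnp.exp(a)))   # inf
stable = logsumexp(a)                  # 1000 + log(2)

# xlogy treats x = 0 as giving 0, even where log(y) is -inf.
safe = xlogy(0.0, 0.0)                 # 0.0
unsafe = 0.0 * jnp.log(0.0)            # nan (0 * -inf)

# expit (logistic sigmoid) undoes logit on (0, 1).
p = jnp.array([0.1, 0.5, 0.9])
roundtrip = expit(logit(p))

# Gamma(5) = 4! = 24, so gammaln(5.0) = log(24).
log_fact4 = gammaln(5.0)
```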

jax.scipy.stats

jax.scipy.stats.beta

logpdf(x, a, b[, loc, scale]) Log of the probability density function at x of the given RV.
pdf(x, a, b[, loc, scale]) Probability density function at x of the given RV.
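
A minimal check of the beta functions against a closed form, assuming a recent JAX release; the shape parameters a = b = 2 are illustrative. For Beta(2, 2) the density is 6x(1 - x) on [0, 1]; with loc and scale, the support becomes [loc, loc + scale] and the density is divided by scale.

```python
import jax.numpy as jnp
from jax.scipy.stats import beta

x = jnp.array([0.25, 0.5, 0.75])

# Beta(2, 2) density and its closed form 6 x (1 - x).
dens = beta.pdf(x, 2.0, 2.0)
log_dens = beta.logpdf(x, 2.0, 2.0)
closed_form = 6.0 * x * (1.0 - x)

# Rescaled to [0, 2]: evaluate at the midpoint and divide by scale (0.75).
shifted = beta.pdf(1.0, 2.0, 2.0, loc=0.0, scale=2.0)
```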

jax.scipy.stats.expon

logpdf(x[, loc, scale]) Log of the probability density function at x of the given RV.
pdf(x[, loc, scale]) Probability density function at x of the given RV.
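
A sketch comparing expon.pdf with the exponential density exp(-(x - loc) / scale) / scale, which applies for x >= loc and is 0 below loc. It assumes a recent JAX release; scale = 2.0 and the evaluation points are illustrative.

```python
import jax.numpy as jnp
from jax.scipy.stats import expon

x = jnp.array([0.5, 1.0, 2.0])
scale = 2.0

# Exponential density with loc = 0 and the given scale.
dens = expon.pdf(x, scale=scale)
closed_form = jnp.exp(-x / scale) / scale

# Below loc (default 0) the density is zero.
below = expon.pdf(-1.0)
```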

jax.scipy.stats.gamma

logpdf(x, a[, loc, scale]) Log of the probability density function at x of the given RV.
pdf(x, a[, loc, scale]) Probability density function at x of the given RV.
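
A sketch relating gamma to expon, assuming a recent JAX release; the shape values are illustrative. With shape a = 1 the gamma density reduces to the standard exponential, and for a = 3 it equals x**2 * exp(-x) / Gamma(3), where Gamma(3) = 2.

```python
import jax.numpy as jnp
from jax.scipy.stats import expon, gamma

x = jnp.array([0.5, 1.0, 3.0])

# Shape a = 1: identical to the standard exponential density.
g1 = gamma.pdf(x, 1.0)
e1 = expon.pdf(x)

# Shape a = 3: x**a-1 * exp(-x) / Gamma(a) with Gamma(3) = 2.
g3 = gamma.pdf(x, 3.0)
closed_form = x**2 * jnp.exp(-x) / 2.0
```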

jax.scipy.stats.laplace

cdf(x[, loc, scale]) Cumulative distribution function of the given RV.
logpdf(x[, loc, scale]) Log of the probability density function at x of the given RV.
pdf(x[, loc, scale]) Probability density function at x of the given RV.
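
A sketch of the Laplace functions, assuming a recent JAX release; loc = 1.0 and scale = 2.0 are illustrative. The density is exp(-|x - loc| / scale) / (2 * scale), and because the distribution is symmetric about loc, cdf(loc) = 0.5 and cdf(loc - d) + cdf(loc + d) = 1.

```python
import jax.numpy as jnp
from jax.scipy.stats import laplace

loc, scale = 1.0, 2.0
x = jnp.array([-1.0, 1.0, 3.0])  # loc - 2, loc, loc + 2

# Density against its closed form.
dens = laplace.pdf(x, loc, scale)
closed_form = jnp.exp(-jnp.abs(x - loc) / scale) / (2.0 * scale)

# CDF at points placed symmetrically around loc.
probs = laplace.cdf(x, loc, scale)
```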

jax.scipy.stats.logistic

cdf(x) Cumulative distribution function of the given RV.
isf(x) Inverse survival function (inverse of sf) at x of the given RV.
logpdf(x) Log of the probability density function at x of the given RV.
pdf(x) Probability density function at x of the given RV.
ppf(x) Percent point function (inverse of cdf) at x of the given RV.
sf(x) Survival function (1 - cdf) at x of the given RV.
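
A sketch of how the standard logistic functions fit together, assuming a recent JAX release and passing only the positional x shown in the signatures above; the evaluation points are illustrative. The CDF is the logistic sigmoid expit, sf = 1 - cdf, ppf inverts cdf, isf inverts sf, and the density peaks at 0.25 at x = 0.

```python
import jax.numpy as jnp
from jax.scipy.special import expit
from jax.scipy.stats import logistic

x = jnp.array([-2.0, 0.0, 2.0])

c = logistic.cdf(x)               # equals expit(x)
s = logistic.sf(x)                # equals 1 - cdf(x)
back_from_cdf = logistic.ppf(c)   # recovers x
back_from_sf = logistic.isf(s)    # recovers x
peak = logistic.pdf(0.0)          # 1/4
```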

jax.scipy.stats.norm

cdf(x[, loc, scale]) Cumulative distribution function of the given RV.
logcdf(x[, loc, scale]) Log of the cumulative distribution function at x of the given RV.
logpdf(x[, loc, scale]) Log of the probability density function at x of the given RV.
pdf(x[, loc, scale]) Probability density function at x of the given RV.
ppf(q[, loc, scale]) Percent point function (inverse of cdf) at q of the given RV.
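
A sketch of the normal-distribution functions, assuming a recent JAX release with the default float32 precision; loc, scale, and the probabilities are illustrative. ppf and cdf invert each other, logcdf stays finite in the far tail where log(cdf(x)) underflows to -inf, and because these are JAX functions they can be differentiated with jax.grad (for the standard normal, the derivative of logpdf at x is -x).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Round trip through the quantile function of N(loc=1, scale=2).
q = jnp.array([0.025, 0.5, 0.975])
z = norm.ppf(q, loc=1.0, scale=2.0)
q_back = norm.cdf(z, loc=1.0, scale=2.0)

# Far tail: cdf underflows to 0 in float32, but logcdf remains finite.
tail_naive = jnp.log(norm.cdf(-40.0))   # -inf
tail_stable = norm.logcdf(-40.0)        # roughly -804.6

# Gradient of the standard normal log-density: d/dx (-x**2 / 2) = -x.
slope = jax.grad(norm.logpdf)(1.0)      # -1.0
```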

jax.scipy.stats.uniform

logpdf(x[, loc, scale]) Log of the probability density function at x of the given RV.
pdf(x[, loc, scale]) Probability density function at x of the given RV.
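
A sketch of the uniform distribution's loc/scale convention, assuming a recent JAX release; loc = 1.0 and scale = 4.0 are illustrative. The support is [loc, loc + scale], the density is 1 / scale inside it and 0 outside, so logpdf is -inf outside the support. The evaluation points avoid the exact endpoints.

```python
import jax.numpy as jnp
from jax.scipy.stats import uniform

# Uniform on [loc, loc + scale] = [1.0, 5.0].
loc, scale = 1.0, 4.0
x = jnp.array([0.0, 2.0, 4.5, 6.0])

dens = uniform.pdf(x, loc, scale)         # 0, 1/4, 1/4, 0
log_dens = uniform.logpdf(x, loc, scale)  # -inf outside the support
```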