jax.scipy.linalg.lu_factor
- jax.scipy.linalg.lu_factor(a, overwrite_a=False, check_finite=True)
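Factorization for LU-based linear solves.

JAX implementation of scipy.linalg.lu_factor(). This function returns a result suitable for use with jax.scipy.linalg.lu_solve(). For direct LU decompositions, prefer jax.scipy.linalg.lu().

- Parameters:
  - a (ArrayLike): input array of shape (..., M, N).
  - overwrite_a (bool): unused by JAX.
  - check_finite (bool): unused by JAX.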
- Returns:
  A tuple (lu, piv):
  - lu is an array of shape (..., M, N), containing L in its lower triangle and U in its upper triangle. The unit diagonal of L is not stored.
  - piv is an array of shape (..., K) with K = min(M, N), which encodes the pivots.
- Return type:
  tuple[Array, Array]
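The Returns description only says that piv "encodes the pivots". The sketch below assumes the LAPACK-style convention that SciPy documents for its own lu_factor (row i was interchanged with row piv[i], with the swaps applied in order i = 0, 1, ..., K-1). It decodes piv into a row permutation for a tall 3-by-2 input (so K = min(3, 2) = 2), rebuilds L and U from the packed lu array, and compares them with the explicit factors returned by jax.scipy.linalg.lu(), which are expected to agree. pivots_to_permutation is a hypothetical helper written only for this illustration; JAX also provides jax.lax.linalg.lu_pivots_to_permutation for the same conversion.

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import lu, lu_factor


def pivots_to_permutation(piv, m):
    """Hypothetical helper (not part of JAX): turn LAPACK-style pivots into a
    row permutation by swapping row i with row piv[i], for i = 0, 1, ..., K-1."""
    perm = np.arange(m)
    for i, j in enumerate(np.asarray(piv)):
        perm[i], perm[j] = perm[j], perm[i]
    return perm


# A tall matrix: M = 3, N = 2, so K = min(M, N) = 2.
a = jnp.array([[1., 2.],
               [3., 4.],
               [5., 6.]])
m, n = a.shape
k = min(m, n)

lu_packed, piv = lu_factor(a)
print(lu_packed.shape, piv.shape)  # (3, 2) (2,)

# Unpack the compact result: L is unit lower-triangular (its diagonal of ones
# is implicit and added back here), U is upper-triangular.
L = jnp.tril(lu_packed, k=-1)[:, :k] + jnp.eye(m, k, dtype=a.dtype)
U = jnp.triu(lu_packed)[:k, :]

# Permuting the rows of `a` with the decoded pivots should reproduce L @ U.
perm = pivots_to_permutation(piv, m)
print(jnp.allclose(a[perm], L @ U, atol=1e-5))  # True

# For comparison, lu() returns the factors explicitly, with the permutation
# as a dense (M, M) matrix p such that a = p @ l @ u.
p, l, u = lu(a)
print(jnp.allclose(L, l), jnp.allclose(U, u))  # True True
print(jnp.allclose(p @ l @ u, a, atol=1e-5))   # True
```

The packed form is what lu_solve() consumes directly, while the explicit form from lu() is easier to inspect but materializes a full M-by-M permutation matrix.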
Examples
Solving a small linear system via LU factorization:
>>> a = jnp.array([[2., 1.],
...                [1., 2.]])
Compute the LU factorization via lu_factor(), and use it to solve a linear equation via lu_solve():

>>> b = jnp.array([3., 4.])
>>> lufac = jax.scipy.linalg.lu_factor(a)
>>> y = jax.scipy.linalg.lu_solve(lufac, b)
>>> y
Array([0.6666666, 1.6666667], dtype=float32)
Check that the result is consistent:
>>> jnp.allclose(a @ y, b)
Array(True, dtype=bool)
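The benefit of a separate factorization step is that the O(n³) work happens once in lu_factor(), and each later lu_solve() call only performs the cheaper O(n²) triangular solves per right-hand side. A minimal sketch, using an arbitrary non-symmetric matrix chosen for illustration: one factorization is reused for a (2, 3) array whose columns are three right-hand sides, and again with lu_solve()'s trans=1 option to solve the transposed system a.T @ x = b.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

# A non-symmetric matrix, so that solving with a and with a.T give different systems.
a = jnp.array([[3., 1.],
               [2., 4.]])

# Factor once (the O(n^3) step) ...
lufac = lu_factor(a)

# ... then reuse the factors. With a 2-D right-hand side, each column of B is
# treated as a separate system a @ x = B[:, j].
B = jnp.array([[1., 0., 4.],
               [0., 1., 6.]])
X = lu_solve(lufac, B)
print(X.shape)                            # (2, 3)
print(jnp.allclose(a @ X, B, atol=1e-5))  # True

# trans=1 solves the transposed system a.T @ x = b with the same factors.
b = jnp.array([5., 5.])
x_t = lu_solve(lufac, b, trans=1)
print(jnp.allclose(a.T @ x_t, b, atol=1e-5))  # True
```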