jax.scipy.linalg.lu_factor

jax.scipy.linalg.lu_factor(a, overwrite_a=False, check_finite=True)

Factorization for LU-based linear solves

JAX implementation of scipy.linalg.lu_factor().

This function returns a result suitable for use with jax.scipy.linalg.lu_solve(). To obtain the explicit permutation, L, and U factors instead, use jax.scipy.linalg.lu().
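Separating the factorization from the solve means one (lu, piv) pair can be reused for many right-hand sides, so the O(n^3) factorization is paid only once. The following is a minimal sketch of that pattern; the variable names are illustrative, and it reuses the matrix from the example at the end of this page.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

# Coefficient matrix (same as in the example at the end of this page).
a = jnp.array([[2., 1.],
               [1., 2.]])

# Factor once.
lu_and_piv = jax.scipy.linalg.lu_factor(a)

# Reuse the same factorization for several right-hand-side vectors.
b1 = jnp.array([3., 4.])
b2 = jnp.array([1., 0.])
x1 = jax.scipy.linalg.lu_solve(lu_and_piv, b1)
x2 = jax.scipy.linalg.lu_solve(lu_and_piv, b2)

# Several right-hand sides can also be stacked as the columns of a matrix.
b3 = jnp.array([0., 1.])
B = jnp.stack([b1, b2, b3], axis=1)  # shape (2, 3)
X = jax.scipy.linalg.lu_solve(lu_and_piv, B)  # shape (2, 3)
```

Each column of X solves a @ x = b for the corresponding column of B.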

Parameters:
  • a (jax.typing.ArrayLike) – input array of shape (..., M, N).

  • overwrite_a (bool) – unused by JAX

  • check_finite (bool) – unused by JAX
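Any leading dimensions of a are treated as batch dimensions, and M need not equal N. The sketch below shows the resulting output shapes for a batch of square matrices and for a single tall matrix; the random inputs are arbitrary and only serve to illustrate shapes.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

key = jax.random.PRNGKey(0)

# A batch of five 3x3 matrices: shape (..., M, N) = (5, 3, 3).
batch = jax.random.normal(key, (5, 3, 3))
lu_b, piv_b = jax.scipy.linalg.lu_factor(batch)
print(lu_b.shape, piv_b.shape)  # (5, 3, 3) (5, 3)

# A single tall 4x3 matrix: K = min(M, N) = 3 pivot entries.
tall = jax.random.normal(key, (4, 3))
lu_t, piv_t = jax.scipy.linalg.lu_factor(tall)
print(lu_t.shape, piv_t.shape)  # (4, 3) (3,)
```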

Returns:

A tuple (lu, piv)

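  • lu is an array of shape (..., M, N) that packs both factors: U occupies the upper triangle including the diagonal, and L occupies the strictly lower triangle (the unit diagonal of L is implicit and not stored).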

  • piv is an integer array of shape (..., K) with K = min(M, N) holding the pivot indices: during factorization, row i of the matrix was interchanged with row piv[i].

Return type:

tuple[Array, Array]
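To make the packed layout concrete, the sketch below unpacks lu into L and U and replays the row interchanges recorded in piv, then checks that the permuted rows of a equal L @ U. It assumes a square input and LAPACK-style pivots (row i interchanged with row piv[i], applied in order i = 0, 1, ...); in everyday code, jax.scipy.linalg.lu() returns the explicit factors directly.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

# A nonsingular square matrix whose first column forces a row swap.
a = jnp.array([[1., 2., 0.],
               [4., 5., 6.],
               [7., 8., 10.]])
n = a.shape[0]

lu, piv = jax.scipy.linalg.lu_factor(a)

# L: strictly lower part of `lu` plus the implicit unit diagonal.
L = jnp.tril(lu, k=-1) + jnp.eye(n, dtype=lu.dtype)
# U: upper triangle of `lu`, including the diagonal.
U = jnp.triu(lu)

# Replay the interchanges in order: at step i, row i was swapped with row piv[i].
perm = list(range(n))
for i, p in enumerate(piv.tolist()):
    perm[i], perm[p] = perm[p], perm[i]

# The rows of `a`, reordered by `perm`, equal L @ U.
print(jnp.allclose(a[jnp.array(perm)], L @ U, atol=1e-5))  # True
```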

Example

Solving a small linear system via LU factorization:

>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([[2., 1.],
...                [1., 2.]])

Compute the LU factorization with lu_factor(), then pass it to lu_solve() to solve the linear system a @ y = b:

>>> b = jnp.array([3., 4.])
>>> lufac = jax.scipy.linalg.lu_factor(a)
>>> y = jax.scipy.linalg.lu_solve(lufac, b)
>>> y
Array([0.6666666, 1.6666667], dtype=float32)

Check that the solution satisfies the original system:

>>> jnp.allclose(a @ y, b)
Array(True, dtype=bool)