# jax.scipy.linalg.lu_factor

jax.scipy.linalg.lu_factor(a, overwrite_a=False, check_finite=True)

Factorization for LU-based linear solves

JAX implementation of `scipy.linalg.lu_factor()`.

This function returns a result suitable for use with `jax.scipy.linalg.lu_solve()`. For direct LU decompositions, prefer `jax.scipy.linalg.lu()`.
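To illustrate how the two return formats differ, the sketch below (an illustrative example, not part of the official docs) factors the same matrix with `jax.scipy.linalg.lu()`, which returns explicit factors `p`, `l`, `u` satisfying `a = p @ l @ u`, and with `lu_factor()`, which packs both triangular factors into a single array. It assumes both functions compute the same partial-pivoting factorization, in which case the packed array equals `l + u - I`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu, lu_factor

a = jnp.array([[1., 2.],
               [3., 4.]])

# Direct decomposition: explicit permutation, lower, and upper factors.
p, l, u = lu(a)
print(jnp.allclose(p @ l @ u, a, atol=1e-6))  # True

# Packed form for lu_solve(): both triangles share one array and the
# unit diagonal of L is implicit, so the packed array equals l + u - I.
lu_packed, piv = lu_factor(a)
print(jnp.allclose(lu_packed, l + u - jnp.eye(2), atol=1e-6))  # True
```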

Parameters:
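• a (jax.typing.ArrayLike) – input array of shape `(..., M, N)`.

• overwrite_a (bool) – unused by JAX; accepted for compatibility with the SciPy signature.

• check_finite (bool) – unused by JAX; accepted for compatibility with the SciPy signature.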
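The `(..., M, N)` input shape means any leading dimensions are treated as a batch. A minimal sketch, using random input purely for illustration, that shows the output shapes for a batch of two non-square 4×3 matrices (so `K = min(4, 3) = 3`):

```python
import jax
from jax.scipy.linalg import lu_factor

# Two 4x3 matrices stacked along a leading batch dimension.
a = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 3))

lu, piv = lu_factor(a)
print(lu.shape)   # (2, 4, 3): same shape as the input
print(piv.shape)  # (2, 3): K = min(M, N) = 3 pivots per matrix
```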

Returns:

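A tuple `(lu, piv)`, where:

• `lu` is an array of shape `(..., M, N)` containing `U` in its upper triangle and `L` in its strictly lower triangle. The unit diagonal of `L` is implicit and not stored.

• `piv` is an integer array of shape `(..., K)`, with `K = min(M, N)`, encoding the row pivots: row `i` was interchanged with row `piv[i]` (0-based indices).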
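The packed layout can be unpacked by hand. The sketch below (illustrative only; the variable names are arbitrary) rebuilds `L` and `U` from `lu`, replays the row interchanges recorded in `piv` to form a row permutation `perm`, and checks that `a[perm]` equals `L @ U`. It loops over concrete pivot values in Python, so it is meant for eager inspection rather than for use inside `jax.jit`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor

a = jnp.array([[1., 2., 3.],
               [4., 5., 6.],
               [7., 8., 10.]])
lu, piv = lu_factor(a)

M, N = a.shape
K = min(M, N)

# L: strictly lower part of `lu` plus the implicit unit diagonal (M x K).
L = jnp.tril(lu[:, :K], k=-1) + jnp.eye(M, K, dtype=lu.dtype)
# U: upper triangle of the first K rows of `lu` (K x N).
U = jnp.triu(lu[:K, :])

# Replay the row interchanges in order: row i was swapped with row piv[i].
perm = list(range(M))
for i in range(K):
    j = int(piv[i])
    perm[i], perm[j] = perm[j], perm[i]
perm = jnp.array(perm)

# The row-permuted input should match the product of the factors.
print(jnp.allclose(a[perm], L @ U, atol=1e-5))  # True
```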

Return type:

`tuple[jax.Array, jax.Array]`

Example

Solving a small linear system via LU factorization:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([[2., 1.],
...                [1., 2.]])
```

Compute the LU factorization with `lu_factor()`, then use it to solve the linear system with `lu_solve()`.

```python
>>> b = jnp.array([3., 4.])
>>> lufac = jax.scipy.linalg.lu_factor(a)
>>> y = jax.scipy.linalg.lu_solve(lufac, b)
>>> y
Array([0.6666666, 1.6666667], dtype=float32)
```

Check that the solution satisfies the original system:

```python
>>> jnp.allclose(a @ y, b)
Array(True, dtype=bool)
```
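Because the factorization is the expensive step (roughly O(n³) work for an n×n matrix, versus O(n²) for the triangular solves per right-hand side), a single `lu_factor()` result can be reused across several right-hand sides. A minimal sketch, assuming `lu_solve()` accepts a 2-D `b` whose columns are separate right-hand sides, as `scipy.linalg.lu_solve()` does:

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

a = jnp.array([[2., 1.],
               [1., 2.]])

# Factor once...
lu_and_piv = lu_factor(a)

# ...then reuse the factorization; each column of `b` is one right-hand side.
b = jnp.array([[3., 1.],
               [4., 0.]])
x = lu_solve(lu_and_piv, b)

print(x.shape)                             # (2, 2)
print(jnp.allclose(a @ x, b, atol=1e-5))   # True
```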