jax.numpy.linalg.matrix_rank

jax.numpy.linalg.matrix_rank(M, rtol=None, *, tol=Deprecated)

Compute the rank of a matrix.

JAX implementation of numpy.linalg.matrix_rank().

The rank is computed from the singular value decomposition (SVD): it is the number of singular values greater than the tolerance, that is, greater than rtol times the largest singular value.
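A minimal sketch of this rule, written with the public jnp.linalg.svd function. The helper name matrix_rank_sketch is hypothetical and not part of JAX, and the default rtol of max(m, n) * eps (NumPy's convention) is an assumption; JAX's own default may differ in detail.

```python
import jax.numpy as jnp


def matrix_rank_sketch(M, rtol=None):
    """Count singular values above rtol * largest singular value.

    Hypothetical helper for illustration; not part of JAX.
    """
    M = jnp.asarray(M)
    # Promote integer or boolean input to a floating dtype for the SVD.
    M = M.astype(jnp.promote_types(M.dtype, jnp.float32))
    # Singular values only, shape (..., min(m, n)).
    s = jnp.linalg.svd(M, compute_uv=False)
    if rtol is None:
        # Assumed default (NumPy's convention): max(m, n) * machine epsilon.
        rtol = max(M.shape[-2:]) * jnp.finfo(s.dtype).eps
    # Absolute cutoff per matrix, shape (..., 1) so it broadcasts against s.
    cutoff = jnp.asarray(rtol, dtype=s.dtype)[..., None] * s.max(axis=-1, keepdims=True)
    # Rank = number of singular values strictly above the cutoff.
    return jnp.sum(s > cutoff, axis=-1)
```

For example, matrix_rank_sketch(jnp.array([[1, 0], [0, 0]])) returns 1: the singular values are 1 and 0, and only the first exceeds the cutoff.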

Parameters:
  • M (ArrayLike) – array of shape (..., m, n) whose rank is to be computed.

  • rtol (ArrayLike | None) – optional array of shape (...) specifying the relative tolerance. Singular values smaller than rtol * largest_singular_value are considered to be zero. If rtol is None (the default), a reasonable default is chosen based on the floating-point precision of the input.

  • tol (ArrayLike | DeprecatedArg | None) – deprecated alias for rtol; use rtol instead.

Returns:

array of shape M.shape[:-2] giving the rank of each matrix in the batch (a scalar for a single 2-D input).

Return type:

Array
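For batched input the function returns one rank per matrix, so the result has the batch shape M.shape[:-2]. The stack below is an illustrative choice (not from the original examples) whose singular values are exact, so the ranks are unambiguous:

```python
import jax.numpy as jnp

# A batch of three 2x2 matrices: shape (3, 2, 2).
batch = jnp.array([
    [[1, 0], [0, 1]],  # identity: singular values 1, 1 -> rank 2
    [[1, 0], [0, 0]],  # singular values 1, 0 -> rank 1
    [[0, 0], [0, 0]],  # zero matrix -> rank 0
])

ranks = jnp.linalg.matrix_rank(batch)
print(ranks.shape)  # (3,)
print(ranks)        # [2 1 0]
```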

Notes

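The computed rank can be unreliable for numerically ill-conditioned matrices, whose smallest singular values lie close to the cutoff rtol * largest_singular_value: small rounding errors in the SVD can move such a value to either side of the cutoff. In such cases, choose rtol to match the precision you expect of the input data, or use a more specialized rank-estimation method.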
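To illustrate the note above: the diagonal matrix below is full rank, but its singular values are 1 and 1e-3, so whether the smaller one counts depends entirely on rtol. Because the largest singular value is 1, the cutoff rtol * largest_singular_value equals rtol here. The matrix and the rtol values are illustrative choices.

```python
import jax.numpy as jnp

# Full-rank diagonal matrix with singular values 1 and 1e-3
# (condition number of about 1000).
A = jnp.array([[1.0, 0.0],
               [0.0, 1e-3]])

# Default tolerance is near machine precision, far below 1e-3, so it counts.
print(jnp.linalg.matrix_rank(A))             # 2
# 1e-3 < 1e-2 * 1.0, so the small singular value is treated as zero.
print(jnp.linalg.matrix_rank(A, rtol=1e-2))  # 1
# 1e-3 > 1e-4 * 1.0, so it is kept.
print(jnp.linalg.matrix_rank(A, rtol=1e-4))  # 2
```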

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.matrix_rank(a)
Array(2, dtype=int32)
>>> b = jnp.array([[1, 0],  # Rank-deficient matrix
...                [0, 0]])
>>> jnp.linalg.matrix_rank(b)
Array(1, dtype=int32)