jax.numpy.linalg.cond



jax.numpy.linalg.cond(x, p=None)

Compute the condition number of a matrix.

JAX implementation of numpy.linalg.cond().

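The condition number is defined as norm(x, p) * norm(inv(x), p). For p = 2, which is also what the default p = None computes, the condition number is the ratio of the largest to the smallest singular value of x; for p = -2 it is the ratio of the smallest to the largest.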
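The two descriptions of the p = 2 condition number can be checked against each other numerically. The sketch below uses an arbitrary invertible matrix chosen only for illustration and compares jnp.linalg.cond with the norm-based definition and with the singular-value ratio.

```python
import jax.numpy as jnp

# An arbitrary invertible 2x2 matrix, chosen only for illustration.
x = jnp.array([[4.0, 1.0],
               [2.0, 3.0]])

# Definition: norm(x, p) * norm(inv(x), p), here with p = 2 (spectral norm).
by_definition = jnp.linalg.norm(x, 2) * jnp.linalg.norm(jnp.linalg.inv(x), 2)

# Equivalent form for p = 2: largest singular value / smallest singular value.
s = jnp.linalg.svd(x, compute_uv=False)  # singular values, in descending order
by_svd = s[0] / s[-1]

# All three values should agree up to float32 rounding.
print(jnp.linalg.cond(x), by_definition, by_svd)
```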

Parameters:
  • x (jax.typing.ArrayLike) – array of shape (..., M, N) for which to compute the condition number.

  • p – the order of the norm to use. One of {None, 1, -1, 2, -2, inf, -inf, 'fro'}; see jax.numpy.linalg.norm() for the meaning of these. The default is p = None, which is equivalent to p = 2. If p is not one of {None, 2, -2}, then x must be square, i.e. M = N.
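A minimal sketch of the p parameter on an illustrative matrix: every listed order works for a square input, while a non-square input is only accepted for p in {None, 2, -2}, where the result is computed from singular values. The matrices x and y are hypothetical examples, not taken from the reference above.

```python
import jax.numpy as jnp

# Square, invertible example matrix (illustrative).
x = jnp.array([[4.0, 1.0],
               [2.0, 3.0]])

# Every supported order can be used with a square matrix.
orders = [None, 1, -1, 2, -2, jnp.inf, -jnp.inf, 'fro']
results = {p: jnp.linalg.cond(x, p=p) for p in orders}
for p, c in results.items():
    print(f"p={p!r}: {c}")

# A non-square matrix of shape (2, 3) with singular values 2 and 1.
y = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 2.0, 0.0]])
print(jnp.linalg.cond(y))        # largest / smallest singular value
print(jnp.linalg.cond(y, p=-2))  # smallest / largest singular value
```

For this particular x, the 1-, inf- and Frobenius-norm condition numbers all happen to equal 3, while p = -2 gives the reciprocal of the p = 2 result.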

Returns:

array of shape x.shape[:-2] containing the condition number of each matrix in x.
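Because leading dimensions are treated as batch dimensions, a stack of matrices yields one condition number per matrix. The sketch below uses a hypothetical batch of three 2x2 matrices whose singular values are easy to read off.

```python
import jax.numpy as jnp

# A stack of three 2x2 matrices: shape (3, 2, 2).
batch = jnp.array([
    [[1.0, 0.0], [0.0, 1.0]],  # identity: singular values 1 and 1
    [[2.0, 0.0], [0.0, 1.0]],  # diagonal: singular values 2 and 1
    [[1.0, 2.0], [2.0, 1.0]],  # symmetric: singular values 3 and 1
])

c = jnp.linalg.cond(batch)
print(c.shape)  # (3,), i.e. batch.shape[:-2]
print(c)        # one condition number per matrix, approximately [1, 2, 3]
```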

Examples

Well-conditioned matrix:

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2],
...                [2, 1]])
>>> jnp.linalg.cond(x)
Array(3., dtype=float32)

Ill-conditioned matrix (this one is singular, so its condition number is infinite):

>>> x = jnp.array([[1, 2],
...                [0, 0]])
>>> jnp.linalg.cond(x)
Array(inf, dtype=float32)
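A matrix can also be ill-conditioned without being singular, in which case the condition number is large but finite. The sketch below uses a hypothetical nearly singular matrix whose second row differs from the first by a small amount eps; for small eps its 2-norm condition number is roughly 4 / eps.

```python
import jax.numpy as jnp

eps = 1e-3
# Nearly singular: the two rows differ only by eps in one entry.
x = jnp.array([[1.0, 1.0],
               [1.0, 1.0 + eps]])

c = jnp.linalg.cond(x)
print(c)  # large but finite, on the order of 4 / eps
```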