jax.numpy.linalg.slogdet

jax.numpy.linalg.slogdet(a, *, method=None)

Computes the sign and (natural) logarithm of the determinant of an array.

JAX implementation of numpy.linalg.slogdet().
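
Because this function mirrors numpy.linalg.slogdet(), the two should agree on the same input up to floating-point rounding. The sketch below checks this on the 2x2 matrix used in the examples. The explicit float32 dtype is an illustrative choice, not a documented requirement.

```python
import numpy as np
import jax.numpy as jnp

# Same matrix as in the examples: det = 1*4 - 2*3 = -2.
a = np.array([[1.0, 2.0],
              [3.0, 4.0]], dtype=np.float32)

np_sign, np_logabsdet = np.linalg.slogdet(a)      # NumPy reference result
jnp_sign, jnp_logabsdet = jnp.linalg.slogdet(a)   # JAX result (accepts NumPy input)

# Signs should match exactly; log|det| should match up to rounding (both near log(2)).
print(np_sign, float(jnp_sign))
print(np_logabsdet, float(jnp_logabsdet))
```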

Parameters:
  • a (jax.typing.ArrayLike) – array of shape (..., M, M) for which to compute the sign and log determinant.

  • method (str | None) –

    The method to use for the determinant computation. Options are:

    • 'lu' (default): use the LU decomposition.

    • 'qr': use the QR decomposition.
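
The sketch below illustrates both parameters. It uses a batch of matrices with shape (4, 3, 3), so each result has shape a.shape[:-2] == (4,). It also calls both method options, which should agree for well-conditioned inputs. To illustrate why a QR factorization is enough, it recomputes log|det| from the diagonal of the R factor returned by jnp.linalg.qr(), using the fact that |det Q| = 1. That manual check is an illustration only and is not meant to describe JAX's internal code. The random test matrices (jax.random.normal with key 0) and the tolerances are arbitrary choices.

```python
import jax
import jax.numpy as jnp

# A batch of four random 3x3 matrices: shape (..., M, M) with ... == (4,).
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (4, 3, 3))

sign_lu, logabsdet_lu = jnp.linalg.slogdet(a)               # default method ('lu')
sign_qr, logabsdet_qr = jnp.linalg.slogdet(a, method='qr')  # QR-based method

# Outputs are batched over the leading dimensions.
print(sign_lu.shape, logabsdet_lu.shape)

# Illustrative cross-check: for a = Q R with |det Q| = 1,
# log|det a| equals the sum of log|R_ii| over the diagonal of R.
_, r = jnp.linalg.qr(a)
r_diag = jnp.diagonal(r, axis1=-2, axis2=-1)
logabsdet_manual = jnp.sum(jnp.log(jnp.abs(r_diag)), axis=-1)

print(jnp.allclose(logabsdet_lu, logabsdet_qr, atol=1e-3))
print(jnp.allclose(logabsdet_lu, logabsdet_manual, atol=1e-3))
print(jnp.array_equal(sign_lu, sign_qr))
```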

Returns:

A tuple of arrays (sign, logabsdet), each of shape a.shape[:-2]:

  • sign is the sign of the determinant.

  • logabsdet is the natural log of the determinant’s absolute value.

Return type:

SlogdetResult
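
The returned SlogdetResult can be unpacked like a tuple, as in the examples. The sketch below assumes that it also exposes the two entries as named fields, sign and logabsdet. It then recovers the determinant itself as sign * exp(logabsdet). The test matrix is an arbitrary illustrative choice.

```python
import jax.numpy as jnp

a = jnp.array([[4.0, 3.0],
               [6.0, 3.0]])   # det = 4*3 - 3*6 = -6

result = jnp.linalg.slogdet(a)

# Assumed named fields, matching the tuple entries (sign, logabsdet).
print(result.sign, result.logabsdet)

# Recover the determinant and compare with the direct computation.
det_from_slogdet = result.sign * jnp.exp(result.logabsdet)
print(det_from_slogdet, jnp.linalg.det(a))
```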

See also

jax.numpy.linalg.det(): direct computation of the determinant.
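
A direct determinant can overflow even when its logarithm is modest. This is the usual reason to prefer slogdet for large matrices. In the sketch below, the 50x50 matrix 10 * I is an arbitrary illustrative choice. Its determinant is 10**50, which exceeds the float32 range (about 3.4e38), while log|det| = 50 * log(10) is easily representable.

```python
import jax.numpy as jnp

# Diagonal matrix with 10 on the diagonal: det = 10**50.
a = 10.0 * jnp.eye(50, dtype=jnp.float32)

print(jnp.linalg.det(a))        # overflows to inf in float32

sign, logabsdet = jnp.linalg.slogdet(a)
print(sign, logabsdet)          # sign 1.0, logabsdet = 50 * log(10), roughly 115.13
```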

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> sign, logabsdet = jnp.linalg.slogdet(a)
>>> sign  # -1 indicates negative determinant
Array(-1., dtype=float32)
>>> jnp.exp(logabsdet)  # Absolute value of determinant
Array(2., dtype=float32)
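
Because slogdet is a JAX function, it can also appear inside objectives passed to jax.grad(). The sketch below differentiates the logabsdet output on the same 2x2 matrix. It assumes the standard identity that, for real invertible A, the gradient of log|det A| with respect to A is inv(A).T.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

def log_abs_det(x):
    # Keep only the second output of slogdet (log|det x|), a scalar for 2-D x.
    return jnp.linalg.slogdet(x)[1]

g = jax.grad(log_abs_det)(a)

# Compare against the closed form inv(a).T.
print(g)
print(jnp.linalg.inv(a).T)
```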