jax.scipy.linalg.det

jax.scipy.linalg.det(a, overwrite_a=False, check_finite=True)

Compute the determinant of a square matrix, or of each matrix in a batch.

JAX implementation of scipy.linalg.det().

Parameters:
  • a (jax.typing.ArrayLike) – input array of shape (..., N, N): one or more square matrices, with any leading dimensions treated as batch dimensions

  • overwrite_a (bool) – accepted for SciPy compatibility; unused by JAX

  • check_finite (bool) – accepted for SciPy compatibility; unused by JAX
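Because overwrite_a and check_finite are ignored, passing non-default values for them should not change the result. JAX arrays are also immutable, so the input is never overwritten in place. A minimal check, assuming only the signature shown above:

```python
import jax.numpy as jnp
import jax.scipy.linalg

x = jnp.array([[1., 2.],
               [3., 4.]])

# Default call versus a call with the SciPy-only flags flipped.
d_default = jax.scipy.linalg.det(x)
d_flags = jax.scipy.linalg.det(x, overwrite_a=True, check_finite=False)
print(d_default, d_flags)

# x is unchanged afterwards: JAX arrays cannot be modified in place.
print(x)
```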

Return type:

Array

Returns

Array of determinants, with shape a.shape[:-2] (a scalar for a single matrix).
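To illustrate the output shape, the sketch below builds a hypothetical batch of 2 x 3 matrices, each equal to 2 times the 4x4 identity, so every determinant should be 2**4 = 16 and the trailing (4, 4) dimensions are dropped:

```python
import jax.numpy as jnp
import jax.scipy.linalg

# Batch of shape (2, 3, 4, 4): every entry of the batch is 2 * I_4.
a = 2.0 * jnp.broadcast_to(jnp.eye(4), (2, 3, 4, 4))

d = jax.scipy.linalg.det(a)
print(d.shape)  # one determinant per matrix: a.shape[:-2] == (2, 3)
print(d)
```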

See also

jax.numpy.linalg.det(): NumPy-style determinant API
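Both entry points compute the same quantity, so for a given input they should agree up to floating-point tolerance. A quick comparison, assuming a batch of two 2x2 matrices:

```python
import jax.numpy as jnp
import jax.scipy.linalg

x = jnp.array([[[1., 2.], [3., 4.]],
               [[8., 5.], [7., 9.]]])

d_scipy = jax.scipy.linalg.det(x)  # SciPy-style API
d_numpy = jnp.linalg.det(x)        # NumPy-style API
print(d_scipy, d_numpy)
```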

Examples

Determinant of a small 2D array:

>>> x = jnp.array([[1., 2.],
...                [3., 4.]])
>>> jax.scipy.linalg.det(x)
Array(-2., dtype=float32)

Batch-wise determinant of multiple 2D arrays:

>>> x = jnp.array([[[1., 2.],
...                 [3., 4.]],
...                [[8., 5.],
...                 [7., 9.]]])
>>> jax.scipy.linalg.det(x)
Array([-2., 37.], dtype=float32)
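Like other JAX functions, det composes with transformations such as jax.jit and jax.grad. The sketch below checks the gradient against Jacobi's formula, d det(A)/dA = det(A) * inv(A).T, which holds for invertible A; the comparison is an illustration, not part of the documented API:

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

x = jnp.array([[1., 2.],
               [3., 4.]])

# Compiled determinant; should match the eager value of -2.
det_jit = jax.jit(jax.scipy.linalg.det)
print(det_jit(x))

# Gradient of the determinant with respect to every matrix entry.
g = jax.grad(jax.scipy.linalg.det)(x)

# Jacobi's formula for an invertible matrix: det(A) * inv(A).T
expected = jax.scipy.linalg.det(x) * jnp.linalg.inv(x).T
print(g)
print(expected)
```

For this 2x2 input the gradient is the cofactor matrix [[4, -3], [-2, 1]].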