jax.numpy.linalg.det

jax.numpy.linalg.det = <jax._src.custom_derivatives.custom_jvp object>

Computes the determinant of an array.

JAX implementation of numpy.linalg.det().
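Because this is a JAX implementation of numpy.linalg.det(), its results should agree with NumPy up to floating-point rounding. The sketch below checks this on a random matrix. The seed, matrix size, and tolerances are illustrative assumptions. JAX computes in float32 by default, while NumPy here works in float64, which is why the comparison uses a loose tolerance.

```python
import numpy as np
import jax.numpy as jnp

# Random 4 x 4 matrix; seed and size are arbitrary choices for illustration.
rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))

np_det = np.linalg.det(a)    # NumPy result, float64
jax_det = jnp.linalg.det(a)  # JAX result, float32 unless jax_enable_x64 is set

# The two values agree up to float32 rounding.
print(np_det, float(jax_det))
```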

Parameters:

a (jax.typing.ArrayLike) – array of shape (..., M, M) for which to compute the determinant.

Returns:

An array of determinants of shape a.shape[:-2].

Return type:

Array
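The leading "..." in the input shape means det broadcasts over any number of batch dimensions, and the output keeps exactly those batch dimensions (a.shape[:-2]). The following minimal sketch shows this for a stack of 2 x 3 matrices, each 4 x 4. The batch sizes and random key are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# A stack of 4 x 4 matrices with batch dimensions (2, 3): shape (..., M, M).
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (2, 3, 4, 4))

dets = jnp.linalg.det(a)
print(dets.shape)  # (2, 3), i.e. a.shape[:-2]

# Cross-check one batch entry against the determinant of that single matrix.
single = jnp.linalg.det(a[1, 2])
print(dets[1, 2], single)
```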

See also

jax.scipy.linalg.det(): SciPy-style API for the determinant.
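For a plain square matrix, the NumPy-style and SciPy-style functions are expected to return the same value. The 3 x 3 matrix below was chosen so that the determinant can be checked by hand with a cofactor expansion. This is a sketch for comparison only, and it does not exercise any SciPy-specific options.

```python
import jax.numpy as jnp
import jax.scipy.linalg

a = jnp.array([[4.0, 2.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])

# Cofactor expansion along the first row:
# 4 * (3*2 - 1*1) - 2 * (1*2 - 1*0) + 0 * (...) = 20 - 4 = 16
numpy_style = jnp.linalg.det(a)
scipy_style = jax.scipy.linalg.det(a)
print(numpy_style, scipy_style)  # both approximately 16.0
```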

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.det(a)
Array(-2., dtype=float32)
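The signature shows that det is a jax.custom_jvp object, meaning it carries a hand-written derivative rule. It therefore composes with JAX transformations such as jax.grad. In the following hedged illustration, the gradient for the example matrix is checked against Jacobi's formula, d det(A)/dA = det(A) * inv(A).T. That formula holds only for invertible A, and the tolerance used is an assumption.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])

# Reverse-mode gradient of the determinant with respect to every entry of a.
grad_det = jax.grad(jnp.linalg.det)(a)

# Jacobi's formula for invertible A: d det(A) / dA = det(A) * inv(A).T
expected = jnp.linalg.det(a) * jnp.linalg.inv(a).T

# For a 2 x 2 matrix this is the cofactor matrix [[a11, -a10], [-a01, a00]],
# here approximately [[4, -3], [-2, 1]].
print(grad_det)
print(expected)
```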