jax.numpy.linalg.matrix_norm


jax.numpy.linalg.matrix_norm(x, /, *, keepdims=False, ord='fro')

Compute the norm of a matrix or stack of matrices.

JAX implementation of numpy.linalg.matrix_norm().

Parameters:
  • x (jax.typing.ArrayLike) – array of shape (..., M, N) for which to take the norm.

  • keepdims (bool) – if True, keep the reduced dimensions in the output.

  • ord (str | int) – a string or int specifying the type of norm; the default, 'fro', is the Frobenius norm. See numpy.linalg.norm() for details on available options.

Returns:

array containing the norm of x. Has shape x.shape[:-2] if keepdims is False, or shape (..., 1, 1) if keepdims is True.

Return type:

Array
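The output shapes described above can be checked with a short sketch. The input stack below is an arbitrary illustration, not part of the original documentation; it also shows that the keepdims=True result broadcasts against x, which is convenient for normalizing each matrix in a stack.

```python
import jax.numpy as jnp

# Illustrative input (an assumption for this sketch): a stack of two 3x4 matrices.
x = jnp.arange(24.0).reshape(2, 3, 4)

# Default keepdims=False: the two trailing matrix axes are reduced away.
norms = jnp.linalg.matrix_norm(x)
print(norms.shape)  # (2,)

# keepdims=True: the reduced axes are kept with length 1.
norms_kept = jnp.linalg.matrix_norm(x, keepdims=True)
print(norms_kept.shape)  # (2, 1, 1)

# The (2, 1, 1) result broadcasts against x, so each matrix in the
# stack can be scaled to unit Frobenius norm in one expression.
x_unit = x / norms_kept
```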

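See also

  • jax.numpy.linalg.vector_norm(): computes vector norms.

  • jax.numpy.linalg.norm(): computes general vector or matrix norms.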

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.linalg.matrix_norm(x)
Array(16.881943, dtype=float32)
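Other values of ord select different matrix norms. The sketch below reuses the matrix from the example above and assumes the ord options follow the conventions documented for numpy.linalg.norm(); the singular values from jnp.linalg.svd are computed here only as a cross-check and are not part of the original example.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.],
               [7., 8., 9.]])

# ord=1: largest absolute column sum (3 + 6 + 9).
one_norm = jnp.linalg.matrix_norm(x, ord=1)
print(float(one_norm))  # 18.0

# ord=jnp.inf: largest absolute row sum (7 + 8 + 9).
inf_norm = jnp.linalg.matrix_norm(x, ord=jnp.inf)
print(float(inf_norm))  # 24.0

# ord=2 (spectral norm) and ord='nuc' (nuclear norm) are defined by the
# singular values: the largest one and their sum, respectively.
two_norm = jnp.linalg.matrix_norm(x, ord=2)
nuc_norm = jnp.linalg.matrix_norm(x, ord='nuc')

# Cross-check against singular values computed directly.
s = jnp.linalg.svd(x, compute_uv=False)
print(bool(jnp.allclose(two_norm, s.max(), rtol=1e-4)))
print(bool(jnp.allclose(nuc_norm, s.sum(), rtol=1e-4)))
```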