jax.numpy.linalg.vector_norm



jax.numpy.linalg.vector_norm(x, /, *, axis=None, keepdims=False, ord=2)

Computes the vector norm of a vector or batch of vectors.
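For reference, the quantities selected by ord follow the numpy.linalg.norm() conventions for vectors. In the summary below, p stands for the value of ord and the index i runs over the entries reduced by axis (all entries when axis is None):

$$
\|x\|_p = \Big(\sum_i |x_i|^p\Big)^{1/p}, \qquad \|x\|_2 = \Big(\sum_i |x_i|^2\Big)^{1/2} \quad (\text{default, } p = 2)
$$

$$
\|x\|_{\infty} = \max_i |x_i|, \qquad \|x\|_{-\infty} = \min_i |x_i|, \qquad \text{ord} = 0:\ \#\{\, i : x_i \neq 0 \,\}
$$

As a worked check, for x = [1, 2, 3] the default norm is sqrt(1 + 4 + 9) = sqrt(14) ≈ 3.7416575, which matches the first example on this page.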

JAX implementation of numpy.linalg.vector_norm().
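Because this function mirrors NumPy, a quick sanity check is to compare it against numpy.linalg.norm(), which is available in every NumPy version (numpy.linalg.vector_norm() itself requires NumPy 2.0 or later). The sketch below assumes float32 inputs and compares with np.allclose rather than exact equality, since the two libraries may round slightly differently.

```python
import numpy as np
import jax.numpy as jnp

x = np.array([[1., 2., 3.],
              [4., 5., 7.]], dtype=np.float32)

# axis=None: vector_norm flattens x; ravel explicitly for NumPy, because
# np.linalg.norm treats 2-D input with axis=None as a matrix.
jax_flat = jnp.linalg.vector_norm(x)
np_flat = np.linalg.norm(x.ravel())

# axis=1: one 2-norm per row in both libraries.
jax_rows = jnp.linalg.vector_norm(x, axis=1)
np_rows = np.linalg.norm(x, axis=1)

print(np.allclose(jax_flat, np_flat), np.allclose(jax_rows, np_rows))  # True True
```

Ravelling before the NumPy call keeps the comparison a vector norm for any ord, not only the default.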

Parameters:
  • x (jax.typing.ArrayLike) – N-dimensional array for which to take the norm.

  • axis (int | None) – optional axis along which to compute the vector norm. If None (default), x is flattened and the norm is taken over all values.

  • keepdims (bool) – if True, keep the reduced dimensions in the output as axes of size one.

  • ord (int | str) – a string or int specifying the type of norm; default is the 2-norm. See numpy.linalg.norm() for details on available options.
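The sketch below exercises each parameter on a small 2x2 array chosen so the results are easy to verify by hand. Passing infinity as jnp.inf (a float) follows the numpy.linalg.norm() convention referenced above.

```python
import jax.numpy as jnp

x = jnp.array([[3., -4.],
               [0., 12.]])

# axis=None (default): x is flattened, sqrt(9 + 16 + 0 + 144) = 13.
total = jnp.linalg.vector_norm(x)

# axis=1: one 2-norm per row -> [5., 12.]
rows = jnp.linalg.vector_norm(x, axis=1)

# keepdims=True keeps the reduced axis with size 1 -> shape (2, 1).
rows_kept = jnp.linalg.vector_norm(x, axis=1, keepdims=True)

# ord=1: sum of absolute values per row -> [7., 12.]
l1 = jnp.linalg.vector_norm(x, axis=1, ord=1)

# ord=inf: largest absolute value per row -> [4., 12.]
linf = jnp.linalg.vector_norm(x, axis=1, ord=jnp.inf)

# ord=0: count of nonzero entries per row -> [2., 1.]
l0 = jnp.linalg.vector_norm(x, axis=1, ord=0)
```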

Returns:

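array containing the norm of x. The dimensions reduced over (all dimensions if axis is None) are removed from the result unless keepdims is True.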

Return type:

Array


Examples

Norm of a single vector:

>>> import jax.numpy as jnp
>>> x = jnp.array([1., 2., 3.])
>>> jnp.linalg.vector_norm(x)
Array(3.7416575, dtype=float32)

Norm of a batch of vectors:

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.vector_norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)
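A common use of keepdims=True is rescaling a batch of vectors to unit length, because the kept axis of size one lets the division broadcast row by row. The helper normalize_rows and its eps guard below are hypothetical, a minimal sketch rather than part of the JAX API.

```python
import jax.numpy as jnp

def normalize_rows(x, eps=1e-12):
  """Hypothetical helper: scale each vector along the last axis to unit 2-norm.

  keepdims=True gives norms of shape (..., 1), so x / norms broadcasts
  per vector; eps avoids division by zero for all-zero vectors.
  """
  norms = jnp.linalg.vector_norm(x, axis=-1, keepdims=True)
  return x / jnp.maximum(norms, eps)

x = jnp.array([[1., 2., 3.],
               [4., 5., 7.]])
unit = normalize_rows(x)
print(jnp.linalg.vector_norm(unit, axis=1))  # each entry is close to 1.0
```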