jax.numpy.linalg.vector_norm

jax.numpy.linalg.vector_norm(x, /, *, axis=None, keepdims=False, ord=2)

Computes the vector norm of a vector (or batch of vectors) x.
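
A short usage sketch, assuming a recent JAX release with jnp.linalg.vector_norm available. It contrasts a single vector with a batch of vectors stored as rows, and shows that the default axis=None reduces over the flattened array rather than per vector.

```python
import jax.numpy as jnp

# A single vector: the default ord=2 gives its Euclidean length.
v = jnp.array([3.0, 4.0])
single = jnp.linalg.vector_norm(v)  # 5.0

# A batch of two vectors stored as the rows of a 2-D array.
batch = jnp.array([[3.0, 4.0],
                   [6.0, 8.0]])

# axis=None (the default): the array is flattened, so one norm covers
# all four elements: sqrt(9 + 16 + 36 + 64) = sqrt(125), about 11.18.
flat = jnp.linalg.vector_norm(batch)

# axis=-1: one norm per row, i.e. one norm per vector in the batch.
per_row = jnp.linalg.vector_norm(batch, axis=-1)  # [5., 10.]

print(single, flat, per_row)
```

To get one norm per vector, pass the axis that indexes the vector components explicitly.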

Parameters:
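  • x (jax.typing.ArrayLike) – input array of any number of dimensions.

  • axis (int | None) – axis along which the norm is computed. If None (the default), x is flattened and a single norm over all of its elements is returned.

  • keepdims (bool) – if True, the reduced axis is kept in the result with size 1 (with axis=None, every dimension is kept with size 1), so the output broadcasts against x. Defaults to False.

  • ord (int | str) – order of the norm. Defaults to 2, the Euclidean norm; see numpy.linalg.norm for the supported vector-norm orders.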
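
The sketch below illustrates keepdims and ord on a small 2-D array. Normalizing each row to unit length is just an example use of keepdims=True. Passing ord=jnp.inf (a float, despite the int | str annotation) is assumed to behave as in jnp.linalg.norm, returning the maximum absolute value.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, -2.0, 0.0],
               [3.0, 0.0, -4.0]])

# keepdims=True keeps the reduced axis with size 1, so the result
# broadcasts against x, e.g. to scale each row to unit length.
row_norms = jnp.linalg.vector_norm(x, axis=1, keepdims=True)
print(row_norms.shape)  # (2, 1)
unit_rows = x / row_norms

# With axis=None, keepdims=True keeps every dimension as size 1.
total = jnp.linalg.vector_norm(x, keepdims=True)
print(total.shape)  # (1, 1)

# ord selects the norm: 1 sums absolute values, jnp.inf takes the max.
l1 = jnp.linalg.vector_norm(x, axis=1, ord=1)          # [3., 7.]
linf = jnp.linalg.vector_norm(x, axis=1, ord=jnp.inf)  # [2., 4.]
```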

Return type:

Array
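
For a finite, nonzero ord = p, the returned Array should hold the usual p-norm (sum of |x_i|**p)**(1/p) along the reduced axis. The sketch below checks this against a hypothetical helper, p_norm_reference, written only for illustration and not part of JAX. It also checks the output shape: x.shape with the reduced axis removed.

```python
import jax.numpy as jnp

def p_norm_reference(x, p, axis=None):
    """Hypothetical helper: (sum |x_i|**p) ** (1/p) for finite, nonzero p."""
    return jnp.sum(jnp.abs(x) ** p, axis=axis) ** (1.0 / p)

x = jnp.array([[1.0, -2.0, 2.0],
               [0.5, 0.5, -1.0]])

for p in (1, 2, 3):
    ours = p_norm_reference(x, p, axis=-1)
    theirs = jnp.linalg.vector_norm(x, axis=-1, ord=p)
    # One norm per row, so the returned Array has shape (2,).
    assert theirs.shape == (2,)
    assert jnp.allclose(ours, theirs)
```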