jax.numpy.linalg.diagonal

jax.numpy.linalg.diagonal(x, /, *, offset=0)

Extract the diagonal of a matrix or stack of matrices.

JAX implementation of numpy.linalg.diagonal().

Parameters:
  • x (jax.typing.ArrayLike) – array of shape (..., M, N) from which the diagonal will be extracted.

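  • offset (int) – offset from the main diagonal. Positive values select diagonals above the main diagonal, negative values select diagonals below it. Defaults to 0, the main diagonal.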

Returns:

Array of shape (..., K) where K is the length of the specified diagonal.

Return type:

Array
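The length K is not spelled out above. The following is a minimal sketch of how it appears to depend on M, N, and offset. The helper expected_diag_len is hypothetical (it is not part of JAX), and the formula is an assumption inferred from the usual diagonal semantics, checked here against the shapes that jnp.linalg.diagonal returns.

```python
import jax.numpy as jnp


def expected_diag_len(m, n, offset):
    """Hypothetical helper (not a JAX API): length of a diagonal of an (m, n) matrix."""
    if offset >= 0:
        # Diagonals above the main one start at column `offset`.
        return max(0, min(m, n - offset))
    # Diagonals below the main one start at row `-offset`.
    return max(0, min(m + offset, n))


# A stack of two 3x4 matrices, so M=3 and N=4.
x = jnp.zeros((2, 3, 4))

# Record the output shape for a range of offsets, including out-of-range ones.
shapes = {k: jnp.linalg.diagonal(x, offset=k).shape for k in (-3, -1, 0, 1, 3, 4)}

for k, shape in shapes.items():
    print(k, shape, expected_diag_len(3, 4, k))
```

Under these assumptions, an offset that falls entirely outside the matrix yields an empty trailing dimension (K = 0) rather than an error.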

Examples

Diagonals of a single matrix:

>>> import jax.numpy as jnp
>>> x = jnp.array([[1,  2,  3,  4],
...                [5,  6,  7,  8],
...                [9, 10, 11, 12]])
>>> jnp.linalg.diagonal(x)
Array([ 1,  6, 11], dtype=int32)
>>> jnp.linalg.diagonal(x, offset=1)
Array([ 2,  7, 12], dtype=int32)
>>> jnp.linalg.diagonal(x, offset=-1)
Array([ 5, 10], dtype=int32)

Batched diagonals:

>>> x = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.linalg.diagonal(x)
Array([[ 0,  5, 10],
       [12, 17, 22]], dtype=int32)
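The offset argument applies to every matrix in the batch. The sketch below illustrates this and compares the result with the more general jnp.diagonal called with axis1=-2 and axis2=-1. That the two agree is an assumption based on the documented behavior of both functions, not a statement about how either is implemented.

```python
import jax.numpy as jnp

# Two stacked 3x4 matrices: x[0] holds 0..11 and x[1] holds 12..23.
x = jnp.arange(24).reshape(2, 3, 4)

# First superdiagonal of each matrix in the stack.
upper = jnp.linalg.diagonal(x, offset=1)
print(upper)

# jnp.diagonal over the last two axes should select the same elements.
general = jnp.diagonal(x, offset=1, axis1=-2, axis2=-1)
print(bool((upper == general).all()))

# With its default axes (axis1=0, axis2=1), jnp.diagonal instead pairs the
# batch axis with the row axis, so the result has a different shape.
default_axes = jnp.diagonal(x)
print(default_axes.shape)
```

The practical difference is that jnp.linalg.diagonal always treats the last two axes as the matrix axes, which is the convenient choice for batched input.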