jax.numpy.diagonal

jax.numpy.diagonal(a, offset=0, axis1=0, axis2=1)

Returns the specified diagonal of an array.

JAX implementation of numpy.diagonal().

The JAX version always returns a copy of the diagonal rather than a view of the input, although when it is used within a JIT-compiled function, the compiler may avoid the copy.
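For comparison, the following sketch contrasts the two behaviors. It assumes a NumPy version (1.9 or later) in which numpy.diagonal() returns a read-only view that shares memory with its input. The JAX result is an independent, immutable array, so an .at[...].set() "update" produces yet another new array and leaves the original untouched.

```python
import numpy as np
import jax.numpy as jnp

# NumPy (assumed >= 1.9): the diagonal is a read-only view of the input.
x_np = np.arange(9).reshape(3, 3)
d_np = np.diagonal(x_np)
print(np.shares_memory(d_np, x_np), d_np.flags.writeable)  # True False

# JAX: the diagonal is a separate, immutable array.
x = jnp.arange(9).reshape(3, 3)
d = jnp.diagonal(x)

# Functional "update": returns a new array; d and x are unchanged.
d2 = d.at[0].set(100)
print(d.tolist())   # [0, 4, 8]
print(d2.tolist())  # [100, 4, 8]
print(int(x[0, 0])) # 0
```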

Parameters:
  • a (ArrayLike) – Input array. Must be at least 2-dimensional.

  • offset (int) – optional, default=0. Diagonal offset from the main diagonal. Must be a static integer value. Can be positive or negative.

  • axis1 (int) – optional, default=0. The first axis along which to take the diagonal.

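  • axis2 (int) – optional, default=1. The second axis along which to take the diagonal.

Returns:

An array with one fewer dimension than the input: a 1D array for 2D input and, in general, an (N-1)-dimensional array for N-dimensional input. The axes axis1 and axis2 are removed, and the diagonal is appended as the last axis.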

Return type:

Array
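Because offset determines the length of the result, it must be known at trace time. When calling jnp.diagonal() from your own jitted function, one way to satisfy this is to mark the offset argument as static, as in this sketch. The helper diag_sum is hypothetical and exists only for illustration; each distinct offset value triggers a separate trace and compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `offset` changes the output shape, so it is declared static for jax.jit.
@partial(jax.jit, static_argnames=("offset",))
def diag_sum(x, offset=0):
    """Hypothetical helper: sum of the diagonal selected by `offset`."""
    return jnp.diagonal(x, offset=offset).sum()


x = jnp.array([[1, 2, 3],
               [4, 5, 6],
               [7, 8, 9]])
print(int(diag_sum(x)))             # 15 (1 + 5 + 9)
print(int(diag_sum(x, offset=1)))   # 8  (2 + 6)
print(int(diag_sum(x, offset=-1)))  # 12 (4 + 8)
```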

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6],
...                [7, 8, 9]])
>>> jnp.diagonal(x)
Array([1, 5, 9], dtype=int32)
>>> jnp.diagonal(x, offset=1)
Array([2, 6], dtype=int32)
>>> jnp.diagonal(x, offset=-1)
Array([4, 8], dtype=int32)
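For inputs with more than two dimensions, axis1 and axis2 select which pair of axes forms each matrix. The sketch below uses a 3D array and shows that the remaining axes keep their original order while the diagonal becomes the last axis, whose length is the smaller of the two selected axis sizes (for offset=0).

```python
import jax.numpy as jnp

# Shape (2, 3, 4); element x[i, j, k] equals 12*i + 4*j + k.
x = jnp.arange(24).reshape(2, 3, 4)

# Diagonal of each trailing 3x4 matrix: axes 1 and 2 are removed and a
# diagonal of length min(3, 4) = 3 is appended.
d12 = jnp.diagonal(x, axis1=1, axis2=2)
print(d12.shape)     # (2, 3)
print(d12.tolist())  # [[0, 5, 10], [12, 17, 22]]

# Pairing axes 0 and 2 leaves axis 1 (size 3) in front, followed by a
# diagonal of length min(2, 4) = 2; d02[j, i] == x[i, j, i].
d02 = jnp.diagonal(x, axis1=0, axis2=2)
print(d02.shape)     # (3, 2)
print(d02.tolist())  # [[0, 13], [4, 17], [8, 21]]
```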