jax.numpy.diag_indices

jax.numpy.diag_indices(n, ndim=2)

Return the indices to access the main diagonal of an array.

LAX-backend implementation of numpy.diag_indices(). The original NumPy docstring follows.

This returns a tuple of indices that can be used to access the main diagonal of an array a with a.ndim >= 2 dimensions and shape (n, n, …, n). For a.ndim = 2 this is the usual diagonal; for a.ndim > 2 it is the set of indices that access a[i, i, ..., i] for i = [0..n-1].
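As a minimal sketch of typical usage (the array contents below are arbitrary illustration values), the returned tuple can be used directly as an index to read the diagonal. Because JAX arrays are immutable, writing to the diagonal goes through the `.at[...]` update syntax rather than in-place assignment as in NumPy.

```python
import jax.numpy as jnp

# For a 2-D square array, diag_indices(3) gives a (row, col) pair of index arrays.
rows, cols = jnp.diag_indices(3)

a = jnp.arange(9).reshape(3, 3)

# Read the main diagonal: a[0, 0], a[1, 1], a[2, 2] -> values 0, 4, 8
diag = a[rows, cols]

# JAX arrays cannot be modified in place; .at[...].set(...) returns a new array
# with the diagonal entries replaced and all other entries unchanged.
b = a.at[rows, cols].set(100)
```

The original `a` is left untouched by the update; only `b` carries the new diagonal.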

Parameters
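  • n (int) – The size, along each dimension, of the arrays for which the returned indices can be used.

  • ndim (int, optional) – The number of dimensions. Defaults to 2.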
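To show how the two parameters shape the result, here is a short sketch (the values n = 4 and ndim = 3 are arbitrary choices): the output is a tuple with one index array per dimension, and each array holds 0, 1, ..., n-1. Applied to an (n, n, n) array, those indices select every a[i, i, i].

```python
import jax.numpy as jnp

n, ndim = 4, 3
idx = jnp.diag_indices(n, ndim=ndim)

# One index array per dimension ...
assert len(idx) == ndim
# ... and each of them is simply 0, 1, ..., n-1
for arr in idx:
    assert bool((arr == jnp.arange(n)).all())

# Set every cube[i, i, i] to 1.0 in an otherwise zero (n, n, n) array
cube = jnp.zeros((n,) * ndim).at[idx].set(1.0)
```

Exactly n entries of `cube` are set, one per diagonal position.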