- jax.numpy.diag_indices(n, ndim=2)
Return the indices to access the main diagonal of an array.
LAX-backend implementation of
Original docstring below.
This returns a tuple of indices that can be used to access the main diagonal of an array a with a.ndim >= 2 dimensions and shape (n, n, …, n). For a.ndim = 2 this is the usual diagonal; for a.ndim > 2 this is the set of indices to access a[i, i, ..., i] for i = [0..n-1].
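The following is a minimal sketch of typical usage. It relies only on jax.numpy.diag_indices as documented above, plus JAX's standard .at[...].set() update syntax. Reading the diagonal works as in NumPy. Writing it differs: JAX arrays are immutable, so a NumPy-style in-place assignment such as a[idx] = -1 is not available. The example also shows the ndim=3 case, where the returned tuple holds three identical index arrays.

```python
import jax.numpy as jnp

# Indices for the main diagonal of a 3x3 array (ndim defaults to 2).
rows, cols = jnp.diag_indices(3)

a = jnp.arange(9).reshape(3, 3)
print(a[rows, cols])  # [0 4 8]

# JAX arrays are immutable: .at[...].set() returns a new array
# with the diagonal replaced, leaving `a` unchanged.
b = a.at[jnp.diag_indices(3)].set(-1)

# With ndim=3 the tuple has three index arrays, selecting c[i, i, i].
idx3 = jnp.diag_indices(2, ndim=3)
c = jnp.arange(8).reshape(2, 2, 2)
print(c[idx3])  # [0 7]
```

Because the result is a tuple with one index array per dimension, it can be passed directly as an index, both for reads and inside .at[...].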