jax.numpy.diag_indices
jax.numpy.diag_indices(n, ndim=2)
Return the indices to access the main diagonal of an array.
LAX-backend implementation of diag_indices(). Original docstring below.

This returns a tuple of indices that can be used to access the main diagonal of an array a with a.ndim >= 2 dimensions and shape (n, n, …, n). For a.ndim = 2 this is the usual diagonal; for a.ndim > 2 this is the set of indices to access a[i, i, ..., i] for i = [0..n-1].
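
The description above can be illustrated with a short sketch. The array names (a, b) and sizes are arbitrary choices for illustration, not part of the original docstring. Because JAX arrays are immutable, the example uses the .at[...].set(...) update syntax rather than in-place assignment.

```python
import jax.numpy as jnp

# 2-D case: index arrays for the main diagonal of a 4x4 array.
rows, cols = jnp.diag_indices(4)

a = jnp.arange(16).reshape(4, 4)
diag_2d = a[rows, cols]  # picks a[0, 0], a[1, 1], a[2, 2], a[3, 3]

# JAX arrays are immutable, so updating the diagonal returns a new array
# and leaves `a` unchanged.
a_updated = a.at[rows, cols].set(100)

# 3-D case: ndim=3 yields three index arrays addressing b[i, i, i].
idx3 = jnp.diag_indices(2, ndim=3)
b = jnp.arange(8).reshape(2, 2, 2)
diag_3d = b[idx3]  # picks b[0, 0, 0] and b[1, 1, 1]
```

Here diag_2d holds the values 0, 5, 10, 15 and diag_3d holds 0 and 7, matching the a[i, i, ..., i] pattern described above.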