jax.numpy.diag_indices_from
- jax.numpy.diag_indices_from(arr)
Return the indices to access the main diagonal of an n-dimensional array.
LAX-backend implementation of numpy.diag_indices_from(). Original docstring below.
See diag_indices for full details.
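The relationship to diag_indices can be illustrated with a minimal sketch (assuming a standard JAX install; the array `a` and the variable names are illustrative only). For an n-dimensional input whose dimensions all share the length `a.shape[0]`, diag_indices_from returns the same tuple as diag_indices(a.shape[0], ndim=a.ndim), and that tuple can be used directly to index the main diagonal.

```python
import jax.numpy as jnp

# A square 3x3 example array: [[0, 1, 2], [3, 4, 5], [6, 7, 8]].
a = jnp.arange(9).reshape(3, 3)

# diag_indices_from returns one index array per dimension of `a`.
rows, cols = jnp.diag_indices_from(a)
print(rows.tolist(), cols.tolist())  # [0, 1, 2] [0, 1, 2]

# The same indices, built from the size and number of dimensions of `a`.
r2, c2 = jnp.diag_indices(a.shape[0], ndim=a.ndim)
same_as_diag_indices = bool((rows == r2).all() and (cols == c2).all())

# Indexing with the tuple reads the main diagonal.
diag = a[rows, cols]
print(diag.tolist())  # [0, 4, 8]
```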
- Parameters:
arr (array, at least 2-D) – input array; all of its dimensions must have the same length.
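The following sketch (again assuming a standard JAX install; `x`, `idx`, and `y` are illustrative names) applies the function to a 3-D array and writes to the diagonal. Because JAX arrays are immutable, the update goes through `.at[...].set()` rather than in-place assignment. It also shows that an input whose dimensions differ in length is rejected with a ValueError.

```python
import jax.numpy as jnp

# A 3-D array in which every dimension has length 2.
x = jnp.zeros((2, 2, 2), dtype=jnp.int32)

# A tuple of 3 index arrays, each equal to [0, 1],
# selecting the elements x[0, 0, 0] and x[1, 1, 1].
idx = jnp.diag_indices_from(x)

# JAX arrays are immutable: .at[...].set() returns an updated copy.
y = x.at[idx].set(1)
print(int(y.sum()))  # 2

# Dimensions of unequal length are not accepted.
raised = False
try:
    jnp.diag_indices_from(jnp.zeros((2, 3)))
except ValueError as err:
    raised = True
    print("ValueError:", err)
```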