jax.numpy.tril_indices_from

jax.numpy.tril_indices_from(arr, k=0)

Return the indices for the lower triangle of arr.

LAX-backend implementation of numpy.tril_indices_from().

Original docstring below.

See tril_indices for full details.

Parameters:
  • arr (array_like) – Input array. The returned indices are valid for square arrays with the same dimensions as arr.

  • k (int, optional) – Diagonal offset (see tril for details).

Return type:

tuple[Array, Array]
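
Example:

The following is a minimal usage sketch, assuming a small 3×3 input. The variable names (arr, rows, cols, lower, strict) are illustrative and not part of the API. It shows the returned index pair being used to gather the lower-triangular entries, and how a negative k excludes the main diagonal.

```python
import jax.numpy as jnp

# A 3x3 array to index into:
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]
arr = jnp.arange(9).reshape(3, 3)

# Row and column indices of the lower triangle, main diagonal included (k=0).
rows, cols = jnp.tril_indices_from(arr)

# Gather the lower-triangular entries in row-major order.
lower = arr[rows, cols]

# k=-1 shifts the cutoff one diagonal down, so the main diagonal is excluded.
rows_k, cols_k = jnp.tril_indices_from(arr, k=-1)
strict = arr[rows_k, cols_k]

print(lower)   # entries on and below the main diagonal
print(strict)  # entries strictly below the main diagonal
```

Here lower holds the six entries on or below the diagonal, and strict holds the three entries strictly below it.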