jax.numpy.tril_indices



jax.numpy.tril_indices(n, k=0, m=None)

Return the indices for the lower triangle of an (n, m) array.

LAX-backend implementation of numpy.tril_indices().
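Because this is a reimplementation of the NumPy function, the index values should agree with `numpy.tril_indices` for the same arguments. The sketch below is a quick sanity check of that expectation. The arguments `4, k=-1, m=5` are arbitrary illustrative choices, not taken from the original docs.

```python
import numpy as np
import jax.numpy as jnp

# Same arguments to both implementations: a 4x5 array, strictly below the diagonal.
np_rows, np_cols = np.tril_indices(4, k=-1, m=5)
jnp_rows, jnp_cols = jnp.tril_indices(4, k=-1, m=5)

# The values should match; the JAX version returns jax.Array objects instead of ndarrays.
print(np.array_equal(np_rows, jnp_rows), np.array_equal(np_cols, jnp_cols))  # True True
```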

Original docstring below.

Parameters:
  • n (int) – The row dimension of the arrays for which the returned indices will be valid.

  • k (int, optional) – Diagonal offset (see tril for details). k=0 is the main diagonal, k<0 selects diagonals below it, and k>0 selects diagonals above it.

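  • m (int, optional) – The column dimension of the arrays for which the returned indices will be valid. By default, m is taken to be equal to n.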
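The sketch below illustrates how `k` and `m` change which positions are returned. The shapes and offsets are illustrative choices. The final cross-check, which compares against `jnp.nonzero` applied to a `jnp.tril` mask, is an assumption about equivalent behavior and is not taken from the original docs.

```python
import jax.numpy as jnp

# Default: the main diagonal and everything below it in a 3x3 array.
rows, cols = jnp.tril_indices(3)
print(rows.tolist())  # [0, 1, 1, 2, 2, 2]
print(cols.tolist())  # [0, 0, 1, 0, 1, 2]

# k=-1 excludes the main diagonal (strictly lower triangle).
rows_strict, cols_strict = jnp.tril_indices(3, k=-1)
print(rows_strict.tolist(), cols_strict.tolist())  # [1, 2, 2] [0, 0, 1]

# m=4 makes the array 3x4; k=1 also includes the first superdiagonal.
rows_wide, cols_wide = jnp.tril_indices(3, k=1, m=4)

# Cross-check: the nonzero positions of the jnp.tril mask come back in the
# same row-major order as the indices above.
mask = jnp.tril(jnp.ones((3, 4), dtype=bool), k=1)
mask_rows, mask_cols = jnp.nonzero(mask)
print(rows_wide.tolist() == mask_rows.tolist())  # True
```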

Returns:

inds – The indices for the triangle. The returned tuple contains two arrays, each with the indices along one dimension of the array.

Return type:

tuple of arrays
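The returned tuple is typically used to index an array. The sketch below uses a hypothetical 3x3 array `a` to show reading the lower-triangle entries and writing to them. JAX arrays are immutable, so the write uses the `.at[...].set(...)` update syntax, which returns a new array rather than modifying `a` in place.

```python
import jax.numpy as jnp

# Hypothetical input: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
a = jnp.arange(9).reshape(3, 3)

rows, cols = jnp.tril_indices(a.shape[0])

# Gather the lower-triangle entries (including the diagonal) in row-major order.
lower = a[rows, cols]
print(lower.tolist())  # [0, 3, 4, 6, 7, 8]

# Functional update: zero out the lower triangle and keep the strict upper part.
upper_only = a.at[rows, cols].set(0)
print(upper_only.tolist())  # [[0, 1, 2], [0, 0, 5], [0, 0, 0]]
```

Unpacking into `rows, cols` is only for readability. Indexing with the tuple directly, as in `a[jnp.tril_indices(3)]`, selects the same elements.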