jax.numpy.tril_indices

jax.numpy.tril_indices(*args, **kwargs)

Return the indices for the lower triangle of an (n, m) array.
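A minimal sketch of the basic call, assuming a recent JAX release is installed. The expected values come from listing the positions on or below the main diagonal of a 3x3 array in row-major order.

```python
import jax.numpy as jnp

# Positions on or below the main diagonal of a 3x3 array, in row-major order:
# (0,0), (1,0), (1,1), (2,0), (2,1), (2,2)
rows, cols = jnp.tril_indices(3)
print(rows.tolist())  # [0, 1, 1, 2, 2, 2]
print(cols.tolist())  # [0, 0, 1, 0, 1, 2]
```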

LAX-backend implementation of tril_indices().
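Because this is a LAX-backend wrapper around NumPy's function, its output should match numpy.tril_indices element for element. The check below is a sketch assuming both NumPy and a recent JAX release are installed. It also shows a call inside jax.jit: the n, k and m arguments need to be concrete Python integers there, so the hypothetical helper lower_entries takes n from x.shape, which is static during tracing.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Compare JAX against NumPy for a non-square (4, 5) array with k=-1.
j_rows, j_cols = jnp.tril_indices(4, k=-1, m=5)
n_rows, n_cols = np.tril_indices(4, k=-1, m=5)
np.testing.assert_array_equal(np.asarray(j_rows), n_rows)
np.testing.assert_array_equal(np.asarray(j_cols), n_cols)

@jax.jit
def lower_entries(x):
    # x.shape[0] is a plain Python int during tracing, so it is a valid static n.
    return x[jnp.tril_indices(x.shape[0])]

x = jnp.arange(9).reshape(3, 3)
print(lower_entries(x).tolist())  # [0, 3, 4, 6, 7, 8]
```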

Original docstring below.

Parameters
  • n (int) – The row dimension of the arrays for which the returned indices will be valid.

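  • k (int, optional) – Diagonal offset (see tril for details). The default k=0 includes the main diagonal; a positive k moves the boundary k diagonals up, and a negative k moves it k diagonals down.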

  • m (int, optional) – The column dimension of the arrays for which the returned indices will be valid. By default m is taken equal to n. New in NumPy version 1.9.0.
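The effect of k and m is easiest to see on small shapes. This sketch assumes a recent JAX release; pairs is a hypothetical helper, and each expected list was worked out from the rule that position (i, j) is kept when j <= i + k and j < m.

```python
import jax.numpy as jnp

def pairs(inds):
    # Zip the (rows, cols) tuple into a list of (i, j) positions.
    rows, cols = inds
    return list(zip(rows.tolist(), cols.tolist()))

# k=-1: strictly below the main diagonal of a 3x3 array.
strict = pairs(jnp.tril_indices(3, k=-1))
print(strict)  # [(1, 0), (2, 0), (2, 1)]

# k=1: also include the first diagonal above the main one.
shifted = pairs(jnp.tril_indices(3, k=1))
print(shifted)
# [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]

# m=2: a 3x2 array, so column indices stop at 1.
tall = pairs(jnp.tril_indices(3, m=2))
print(tall)  # [(0, 0), (1, 0), (1, 1), (2, 0), (2, 1)]
```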

Returns

inds – The indices for the lower triangle. The returned tuple contains two arrays: the first holds the row indices and the second the column indices.

Return type

tuple of arrays
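The returned tuple can be passed directly as an index. A sketch, assuming a recent JAX release: plain indexing gathers the lower-triangle entries, and because JAX arrays are immutable, the .at[...].set() syntax returns an updated copy instead of writing in place. The last step packs an illustrative length-6 vector into a lower-triangular matrix.

```python
import jax.numpy as jnp

a = jnp.arange(9).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
inds = jnp.tril_indices(3)

# Gather: read the lower-triangle entries in row-major order.
lower = a[inds]
print(lower.tolist())  # [0, 3, 4, 6, 7, 8]

# Functional update: zero the lower triangle; `a` itself is unchanged.
zeroed = a.at[inds].set(0)
print(zeroed.tolist())  # [[0, 1, 2], [0, 0, 5], [0, 0, 0]]

# Scatter: fill a lower-triangular matrix from a vector of matching length.
vec = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
L = jnp.zeros((3, 3)).at[inds].set(vec)
print(L.tolist())  # [[1.0, 0.0, 0.0], [2.0, 3.0, 0.0], [4.0, 5.0, 6.0]]
```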