jax.numpy.tril_indices

jax.numpy.tril_indices(n, k=0, m=None)

Return the indices of the lower triangle of an array of shape (n, m).

JAX implementation of numpy.tril_indices().

Parameters:
  • n (int) – int. Number of rows of the array for which the indices are returned.

  • k (int) – optional, int, default=0. Specifies the diagonal on and below which the indices of the lower triangle are returned. k=0 refers to the main diagonal, k<0 refers to a diagonal below the main diagonal, and k>0 refers to a diagonal above the main diagonal.

  • m (int | None) – optional, int. Number of columns of the array for which the indices are returned. If not specified, then m = n.
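
The effect of k and m can be checked with a minimal sketch. It assumes the usual NumPy convention that a position (i, j) belongs to the lower triangle when j - i <= k; the values of n and m below are chosen only for illustration.

```python
import jax.numpy as jnp

n, m = 3, 4  # illustrative shape, not taken from the examples on this page

for k in (-1, 0, 1):
    rows, cols = jnp.tril_indices(n, k=k, m=m)
    # Every returned position lies on or below the k-th diagonal.
    assert bool(jnp.all(cols - rows <= k))
    # The number of positions matches a brute-force scan of the (n, m) grid.
    expected = sum(1 for i in range(n) for j in range(m) if j - i <= k)
    assert rows.shape[0] == expected
```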

Returns:

A tuple of two arrays containing the indices of the lower triangle, one along each axis.

Return type:

tuple[Array, Array]
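
Because the result is a tuple of index arrays, it can be passed directly as an index. The following sketch, using an illustrative 3x3 array, reads the lower-triangle entries and builds a copy with those entries set to zero.

```python
import jax.numpy as jnp

a = jnp.arange(1, 10).reshape(3, 3)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
idx = jnp.tril_indices(3)

# Advanced indexing with the (rows, cols) tuple gives a 1-D array
# of the lower-triangle entries in row-major order.
lower_values = a[idx]
print(lower_values)  # [1 4 5 7 8 9]

# JAX arrays are immutable, so .at[...].set(...) returns an updated copy.
upper_only = a.at[idx].set(0)
print(upper_only)
# [[0 2 3]
#  [0 0 6]
#  [0 0 0]]
```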

Examples

If only n is provided, the indices of the lower triangle of an (n, n) array are returned.

>>> import jax.numpy as jnp
>>> jnp.tril_indices(3)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))

If both n and m are provided, the indices of the lower triangle of an (n, m) array are returned.

>>> jnp.tril_indices(3, m=2)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1], dtype=int32))

If k = 1, the indices on and below the first diagonal above the main diagonal are returned.

>>> jnp.tril_indices(3, k=1)
(Array([0, 0, 1, 1, 1, 2, 2, 2], dtype=int32), Array([0, 1, 0, 1, 2, 0, 1, 2], dtype=int32))

If k = -1, the indices on and below the first sub-diagonal below the main diagonal are returned.

>>> jnp.tril_indices(3, k=-1)
(Array([1, 2, 2], dtype=int32), Array([0, 0, 1], dtype=int32))
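
As a cross-check, the returned indices should mark the same positions that jnp.tril keeps. This is a small sketch, assuming both functions follow the same diagonal convention as their NumPy counterparts; the mask construction is for illustration only.

```python
import jax.numpy as jnp

n, k = 3, -1
idx = jnp.tril_indices(n, k=k)

# Mark the returned positions in an all-zero (n, n) grid.
mask_from_indices = jnp.zeros((n, n), dtype=jnp.int32).at[idx].set(1)

# jnp.tril keeps entries on and below the k-th diagonal and zeroes the rest.
mask_from_tril = jnp.tril(jnp.ones((n, n), dtype=jnp.int32), k=k)

assert bool(jnp.array_equal(mask_from_indices, mask_from_tril))
```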