jax.numpy.tril_indices

jax.numpy.tril_indices(*args, **kwargs)

Return the indices for the lower-triangle of an (n, m) array.

LAX-backend implementation of numpy.tril_indices(). The original NumPy docstring follows.

Parameters
  • n (int) – The row dimension of the arrays for which the returned indices will be valid.

  • k (int, optional) – Diagonal offset (see tril for details).

  • m (int, optional) – The column dimension of the arrays for which the returned arrays will be valid. By default m is taken equal to n.

    New in version 1.9.0.

Returns

inds – The indices for the triangle. The returned tuple contains two arrays, each with the indices along one dimension of the array.

Return type

tuple of arrays
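
As a minimal sketch of how n, k, and m interact, the snippet below requests the lower-triangle indices of a rectangular 3x5 array with a diagonal offset of 1 and compares them against a brute-force enumeration. The sizes and the variable names (rows, cols, pairs, expected) are chosen only for illustration, and the call assumes the keyword form jnp.tril_indices(n, k=..., m=...), mirroring the NumPy signature.

```python
import jax.numpy as jnp

n, m, k = 3, 5, 1  # illustrative sizes: 3 rows, 5 columns, one diagonal above the main

# The result is a tuple of two index arrays: row indices and column indices.
rows, cols = jnp.tril_indices(n, k=k, m=m)

# Every returned position (i, j) lies on or below the k-th diagonal, i.e. j - i <= k.
on_or_below = bool(jnp.all(cols - rows <= k))

# Brute-force reference: enumerate the same positions in row-major order.
expected = [(i, j) for i in range(n) for j in range(m) if j - i <= k]
pairs = list(zip(rows.tolist(), cols.tolist()))

print(on_or_below, pairs == expected, len(pairs))
```

Here the two arrays always have the same length, so they can be passed together as a single index into any array of shape (n, m).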

See also

triu_indices()

similar function, for upper-triangular.

mask_indices()

generic function accepting an arbitrary mask function.

tril(), triu()

Notes

New in version 1.4.0.

Examples

Compute two different sets of indices to access 4x4 arrays, one for the lower triangular part starting at the main diagonal, and one starting two diagonals further right:

>>> import numpy as np
>>> il1 = np.tril_indices(4)
>>> il2 = np.tril_indices(4, 2)

Here is how they can be used with a sample array:

>>> a = np.arange(16).reshape(4, 4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])

Both for indexing:

>>> a[il1]
array([ 0,  4,  5, ..., 13, 14, 15])
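
The printed result above is abbreviated. As a small sketch using jax.numpy instead of NumPy, the following lists every selected value, showing that indexing with the tuple yields a 1-D array of the lower-triangle entries in row-major order. The name vals is introduced here only for illustration.

```python
import jax.numpy as jnp

a = jnp.arange(16).reshape(4, 4)

# Indexing with the (rows, cols) tuple gathers the lower-triangle entries
# into a flat array, row by row.
vals = a[jnp.tril_indices(4)]

print(vals.tolist())  # [0, 4, 5, 8, 9, 10, 12, 13, 14, 15]
```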

And for assigning values:

>>> a[il1] = -1
>>> a
array([[-1,  1,  2,  3],
       [-1, -1,  6,  7],
       [-1, -1, -1, 11],
       [-1, -1, -1, -1]])

These cover almost the whole array (two diagonals right of the main one):

>>> a[il2] = -10
>>> a
array([[-10, -10, -10,   3],
       [-10, -10, -10, -10],
       [-10, -10, -10, -10],
       [-10, -10, -10, -10]])
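
The assignment examples above come from the NumPy docstring and rely on in-place mutation. JAX arrays are immutable, so an item assignment such as a[il1] = -1 raises an error on a JAX array. A minimal sketch of the equivalent functional update, using the .at[...].set(...) syntax of JAX arrays, might look like this (the names b and c are illustrative):

```python
import jax.numpy as jnp

a = jnp.arange(16).reshape(4, 4)
il1 = jnp.tril_indices(4)
il2 = jnp.tril_indices(4, 2)

# .at[...].set(...) returns a new array; the original `a` is left unchanged.
b = a.at[il1].set(-1)
c = b.at[il2].set(-10)

print(b)
print(c)
```

Because each update returns a new array, the results have to be bound to a name (or rebound to a) rather than relied on as a side effect.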