jax.numpy.triu_indices

jax.numpy.triu_indices(*args, **kwargs)

Return the indices for the upper-triangle of an (n, m) array.

LAX-backend implementation of numpy.triu_indices(). The original NumPy docstring is reproduced below.

Parameters
  • n (int) – The size of the arrays for which the returned indices will be valid.

  • k (int, optional) – Diagonal offset (see triu for details).

  • m (int, optional) – The column dimension of the arrays for which the returned indices will be valid. By default m is taken equal to n.

    New in version 1.9.0.

Returns

inds – The indices for the upper triangle. The returned tuple contains two arrays, each with the indices along one dimension of the array. It can be used to index an array of shape (n, m).

Return type

tuple of two 1-D arrays of equal length (one entry per element of the selected triangle)
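A minimal sketch of the k and m parameters and of the return value, assuming JAX is installed and imported as jnp (the values n=3, k=1, m=4 are arbitrary illustrations). Both returned arrays are 1-D with one entry per selected element, so their length is generally not n:

```python
import jax.numpy as jnp

# Upper triangle of a 3x4 array, starting one diagonal above the main one.
rows, cols = jnp.triu_indices(3, k=1, m=4)

print(rows.tolist())  # [0, 0, 0, 1, 1, 2]
print(cols.tolist())  # [1, 2, 3, 2, 3, 3]

# Six positions are selected, so each index array has length 6 (not n = 3).
print(rows.shape, cols.shape)  # (6,) (6,)
```

Each pair (rows[i], cols[i]) is one position in the triangle, listed in row-major order.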

See also

tril_indices()

Similar function, for the lower triangle.

mask_indices()

Generic function accepting an arbitrary mask function.

triu(), tril()
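The indices select exactly the entries that triu keeps for the same k. The sketch below, which assumes JAX imported as jnp and an arbitrary 4x4 input, checks that relationship by turning the indices into a boolean mask:

```python
import jax.numpy as jnp

a = jnp.arange(16).reshape(4, 4)
iu = jnp.triu_indices(4, k=1)

# Boolean mask that is True exactly at the positions returned by triu_indices.
mask = jnp.zeros((4, 4), dtype=bool).at[iu].set(True)

# Zeroing every entry outside the mask reproduces jnp.triu with the same k.
masked = jnp.where(mask, a, 0)
print(bool(jnp.all(masked == jnp.triu(a, k=1))))  # True
```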

Notes

New in version 1.4.0.

Examples

Compute two different sets of indices to access 4x4 arrays, one for the upper triangular part starting at the main diagonal, and one starting two diagonals further right:

>>> import numpy as np
>>> iu1 = np.triu_indices(4)
>>> iu2 = np.triu_indices(4, 2)

Here is how they can be used with a sample array:

>>> a = np.arange(16).reshape(4, 4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])

Both for indexing:

>>> a[iu1]
array([ 0,  1,  2, ..., 10, 11, 15])
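Read-only indexing works the same way with jax.numpy. A short sketch, assuming JAX imported as jnp, that gathers the upper-triangle entries of the same 4x4 array:

```python
import jax.numpy as jnp

a = jnp.arange(16).reshape(4, 4)
iu1 = jnp.triu_indices(4)

# Integer-array indexing with the index tuple gathers the
# upper-triangle entries in row-major order.
upper = a[iu1]
print(upper.tolist())  # [0, 1, 2, 3, 5, 6, 7, 10, 11, 15]
```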

And for assigning values:

>>> a[iu1] = -1
>>> a
array([[-1, -1, -1, -1],
       [ 4, -1, -1, -1],
       [ 8,  9, -1, -1],
       [12, 13, 14, -1]])
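Item assignment such as a[iu1] = -1 works on NumPy arrays but not on JAX arrays, which are immutable and raise a TypeError. A sketch of the JAX equivalent, assuming JAX imported as jnp, uses the functional .at[...].set(...) update, which returns a modified copy:

```python
import jax.numpy as jnp

a = jnp.arange(16).reshape(4, 4)
iu1 = jnp.triu_indices(4)

# In-place assignment is rejected because JAX arrays are immutable.
try:
    a[iu1] = -1
except TypeError:
    print("JAX arrays do not support item assignment")

# .at[...].set(...) returns a new array with the selected entries replaced.
b = a.at[iu1].set(-1)
print(b.tolist())
# [[-1, -1, -1, -1], [4, -1, -1, -1], [8, 9, -1, -1], [12, 13, 14, -1]]

print(int(a[0, 0]))  # 0, the original array is unchanged
```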

The iu2 indices cover only a small part of the array (the entries two or more diagonals to the right of the main one):

>>> a[iu2] = -10
>>> a
array([[ -1,  -1, -10, -10],
       [  4,  -1,  -1, -10],
       [  8,   9,  -1,  -1],
       [ 12,  13,  14,  -1]])
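Because triu_indices needs concrete (non-traced) arguments, code that calls it under jax.jit has to keep n, k, and m static. A minimal sketch, where fill_upper is a hypothetical helper (not part of JAX): k is marked static, while n and m come from the array shape, which is always static under jit. It reproduces both assignments from the example above:

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: k must be a Python int, so it is a static argument.
@partial(jax.jit, static_argnums=1)
def fill_upper(x, k, value):
    # x.shape is known at trace time, so n and m are concrete here.
    iu = jnp.triu_indices(x.shape[0], k=k, m=x.shape[1])
    return x.at[iu].set(value)


a = jnp.arange(16).reshape(4, 4)
a = fill_upper(a, 0, -1)   # corresponds to a[iu1] = -1
a = fill_upper(a, 2, -10)  # corresponds to a[iu2] = -10
print(a.tolist())
# [[-1, -1, -10, -10], [4, -1, -1, -10], [8, 9, -1, -1], [12, 13, 14, -1]]
```

Each distinct value of k is a separate static argument, so calling fill_upper with a new k triggers a fresh compilation.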