jax.numpy.triu_indices

jax.numpy.triu_indices(*args, **kwargs)

Return the indices for the upper-triangle of an (n, m) array.

LAX-backend implementation of triu_indices(). Original docstring below.

n : int

The row dimension of the arrays for which the returned indices will be valid.

k : int, optional

Diagonal offset (see triu for details).

m : int, optional

New in NumPy version 1.9.0.

The column dimension of the arrays for which the returned arrays will be valid. By default m is taken equal to n.

inds : tuple, shape(2) of ndarrays, shape(n)

The indices for the triangle. The returned tuple contains two arrays, each with the indices along one dimension of the array. Can be used to index an ndarray of shape (n, m).
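As a minimal sketch of how n, k, and m interact, the snippet below builds the indices for a non-square 3x4 array and skips the main diagonal with k=1. The import alias jnp and all variable names are illustrative assumptions, not part of the original docstring. It also checks the result against the nonzero positions of jnp.triu applied to an array of ones.

```python
import jax.numpy as jnp

# Indices of the upper triangle of a 3x4 array, starting one
# diagonal above the main one: n=3 rows, m=4 columns, k=1.
rows, cols = jnp.triu_indices(3, k=1, m=4)

# The same positions are the nonzero entries of jnp.triu
# applied to a 3x4 array of ones with the same offset.
mask = jnp.triu(jnp.ones((3, 4)), k=1)
mask_rows, mask_cols = jnp.nonzero(mask)

print(rows)  # row index of each selected element
print(cols)  # column index of each selected element
```

The two returned arrays have equal length, one entry per selected element, listed in row-major order.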

tril_indices : similar function, for lower-triangular.

mask_indices : generic function accepting an arbitrary mask function.

triu, tril

New in NumPy version 1.4.0.
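To illustrate the relationship with tril_indices, here is a small sketch (variable names are hypothetical) showing that the strictly upper triangle (k=1) and the lower triangle including the diagonal together cover every element of a square array exactly once.

```python
import jax.numpy as jnp

n = 4

# Strictly upper triangle: k=1 skips the main diagonal.
upper = jnp.triu_indices(n, k=1)
# Lower triangle, including the main diagonal.
lower = jnp.tril_indices(n)

# Count how many times each position is hit by either index set.
covered = jnp.zeros((n, n), dtype=jnp.int32)
covered = covered.at[upper].add(1)
covered = covered.at[lower].add(1)

print(covered)  # every entry is hit exactly once
```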

Compute two different sets of indices to access 4x4 arrays, one for the upper triangular part starting at the main diagonal, and one starting two diagonals further right:

>>> import numpy as np
>>> iu1 = np.triu_indices(4)
>>> iu2 = np.triu_indices(4, 2)

Here is how they can be used with a sample array:

>>> a = np.arange(16).reshape(4, 4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])

Both for indexing:

>>> a[iu1]
array([ 0,  1,  2, ..., 10, 11, 15])

And for assigning values:

>>> a[iu1] = -1
>>> a
array([[-1, -1, -1, -1],
       [ 4, -1, -1, -1],
       [ 8,  9, -1, -1],
       [12, 13, 14, -1]])

The indices in iu2 cover only a small part of the whole array (starting two diagonals to the right of the main one):

>>> a[iu2] = -10
>>> a
array([[ -1,  -1, -10, -10],
       [  4,  -1,  -1, -10],
       [  8,   9,  -1,  -1],
       [ 12,  13,  14,  -1]])
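
The examples above use NumPy arrays, which support in-place assignment. JAX arrays are immutable, so a statement such as a[iu1] = -1 raises an error on a jax.numpy array. A hedged sketch of the equivalent steps with jax.numpy, using the .at[...].set(...) functional update (variable names b and c are illustrative):

```python
import jax.numpy as jnp

a = jnp.arange(16).reshape(4, 4)
iu1 = jnp.triu_indices(4)      # upper triangle from the main diagonal
iu2 = jnp.triu_indices(4, 2)   # upper triangle from two diagonals to the right

# Reading values through the index tuple works as in NumPy.
upper_values = a[iu1]

# Writing goes through .at[...].set(...), which returns a new array
# and leaves the original array a unchanged.
b = a.at[iu1].set(-1)
c = b.at[iu2].set(-10)

print(upper_values)
print(c)
```

Each update returns a fresh array, so a still holds the values 0 through 15 afterwards.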