jax.numpy.triu_indices



jax.numpy.triu_indices(n, k=0, m=None)

Return the indices for the upper-triangle of an (n, m) array.

LAX-backend implementation of numpy.triu_indices().
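A minimal sketch of basic usage, assuming JAX and NumPy are installed. It prints the upper-triangle indices of a 3x3 array (main diagonal included) and checks that they match what numpy.triu_indices() returns.

```python
import numpy as np
import jax.numpy as jnp

# Row and column indices of the upper triangle of a 3x3 array.
rows, cols = jnp.triu_indices(3)
print(rows)  # [0 0 0 1 1 2]
print(cols)  # [0 1 2 1 2 2]

# The LAX-backend version should agree with NumPy's result.
np_rows, np_cols = np.triu_indices(3)
assert np.array_equal(np.asarray(rows), np_rows)
assert np.array_equal(np.asarray(cols), np_cols)
```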

Original docstring below.

Parameters:
  • n (int) – The size of the arrays for which the returned indices will be valid.

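  • k (int, optional) – Diagonal offset (see triu for details). With k=0 (the default) the main diagonal is included; k > 0 starts k diagonals above it, and k < 0 also includes the first |k| subdiagonals.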

  • m (int, optional) – The column dimension of the arrays for which the returned indices will be valid. Defaults to n.
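The effect of k and m is easiest to see on small inputs. This sketch (assuming JAX is installed) compares a positive and a negative diagonal offset on a 3x3 array, and a non-square 3x4 array selected through m.

```python
import jax.numpy as jnp

# k=1 drops the main diagonal; k=-1 adds the first subdiagonal.
above = jnp.triu_indices(3, k=1)
with_sub = jnp.triu_indices(3, k=-1)

# m sets the number of columns for a non-square (n, m) array.
wide = jnp.triu_indices(3, m=4)

print(above[0], above[1])        # [0 0 1] [1 2 2]
print(with_sub[0], with_sub[1])  # [0 0 0 1 1 1 2 2] [0 1 2 0 1 2 1 2]
print(wide[0], wide[1])          # [0 0 0 0 1 1 1 2 2] [0 1 2 3 1 2 3 2 3]
```

Each selected position (i, j) satisfies j - i >= k, which is the same rule triu uses to decide which elements to keep.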

Returns:

inds – The indices for the triangle. The returned tuple contains two arrays, each holding the indices along one dimension of the array. It can be used to index an array of shape (n, m), or (n, n) when m is omitted.

Return type:

tuple of two 1-D arrays, each with one entry per element of the selected triangle
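A short sketch of using the returned tuple as an index, assuming JAX is installed. Because JAX arrays are immutable, the NumPy idiom `a[inds] = value` is replaced here by the `.at[...].set(...)` update syntax, which returns a modified copy.

```python
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

idx = jnp.triu_indices(3)   # tuple (rows, cols)
upper = x[idx]              # gather the upper-triangle values
print(upper)                # [0 1 2 4 5 8]

# Out-of-place update: zero the upper triangle (diagonal included).
# Only the strictly-lower entries 3, 6 and 7 keep their values; x is unchanged.
y = x.at[idx].set(0)
print(y)
```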