jax.numpy.triu_indices

jax.numpy.triu_indices(*args, **kwargs)

Return the indices for the upper triangle of an (n, m) array.

LAX-backend implementation of triu_indices().
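A minimal sketch of basic usage, assuming a recent JAX release and NumPy are installed. It computes the upper-triangle indices of a 3x3 array (main diagonal included, since k defaults to 0) and checks that the JAX result matches numpy.triu_indices element for element.

```python
import numpy as np
import jax.numpy as jnp

# Row and column indices of the upper triangle of a 3x3 array.
rows, cols = jnp.triu_indices(3)
print(rows)  # [0 0 0 1 1 2]
print(cols)  # [0 1 2 1 2 2]

# The LAX-backed version should agree with NumPy's reference implementation.
np_rows, np_cols = np.triu_indices(3)
assert np.array_equal(np.asarray(rows), np_rows)
assert np.array_equal(np.asarray(cols), np_cols)
```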

Original docstring below.

Parameters
  • n (int) – The size of the arrays for which the returned indices will be valid.

  • k (int, optional) – Diagonal offset (see triu for details): k = 0 is the main diagonal, k > 0 starts above it, and k < 0 starts below it.

  • m (int, optional) – The column dimension of the arrays for which the returned indices will be valid. By default m is taken equal to n.

    New in NumPy version 1.9.0.
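The sketch below illustrates the k and m parameters, assuming a JAX version whose triu_indices accepts k and m as keywords (as NumPy's does). The expected values in the comments follow from selecting every (i, j) with j >= i + k and j < m, in row-major order.

```python
import jax.numpy as jnp

# k=1: strictly upper triangle (main diagonal excluded) of a 3x3 array.
r1, c1 = jnp.triu_indices(3, k=1)
# r1 -> [0 0 1], c1 -> [1 2 2]

# k=-1: the first subdiagonal is included as well.
r2, c2 = jnp.triu_indices(3, k=-1)
# r2 -> [0 0 0 1 1 1 2 2], c2 -> [0 1 2 0 1 2 1 2]

# m=4: indices valid for a non-square 3x4 array.
r3, c3 = jnp.triu_indices(3, k=0, m=4)
# r3 -> [0 0 0 0 1 1 1 2 2], c3 -> [0 1 2 3 1 2 3 2 3]
```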

Returns

inds – The indices for the triangle. The returned tuple contains two arrays, each with the indices along one dimension of the array. It can be used to index an array of shape (n, m).
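A short sketch of using the returned tuple for indexing. Gathering works as in NumPy. For writes, JAX arrays are immutable, so the NumPy idiom a[rows, cols] = 0 is replaced here by the functional .at[...].set(...) update, which returns a new array.

```python
import jax.numpy as jnp

a = jnp.arange(9).reshape(3, 3)  # [[0 1 2], [3 4 5], [6 7 8]]
rows, cols = jnp.triu_indices(3)

# Gather the upper-triangle entries in row-major order.
upper = a[rows, cols]
print(upper)  # [0 1 2 4 5 8]

# Functional scatter: zero the upper triangle, leaving `a` unchanged.
zeroed = a.at[rows, cols].set(0)
# zeroed -> [[0 0 0], [3 0 0], [6 7 0]]
```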

Return type

tuple of two 1-D arrays, each of shape (N,), where N is the number of elements in the selected triangle
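The following sketch checks the return shape: each index array has length N equal to the number of nonzero entries in the corresponding triu mask (N = n(n + 1)/2 for a square array with k = 0). Because N depends on n, k and m, those arguments are assumed to need concrete Python integers; inside jax.jit this works when they come from static array shapes. The helper num_upper is hypothetical, written only for this illustration.

```python
import jax
import jax.numpy as jnp

def num_upper(n, k=0, m=None):
    # Hypothetical helper: count the ones in the triu mask of an (n, m) array.
    m = n if m is None else m
    return int(jnp.triu(jnp.ones((n, m), dtype=jnp.int32), k=k).sum())

for n, k, m in [(4, 0, None), (4, 1, None), (3, 0, 5), (5, -2, 3)]:
    rows, cols = jnp.triu_indices(n, k=k, m=m)
    # Both arrays are 1-D and share the same length N.
    assert rows.shape == cols.shape == (num_upper(n, k, m),)

print(jnp.triu_indices(4)[0].shape)  # (10,)

@jax.jit
def upper_values(x):
    # x.shape is static under jit, so n and m are concrete Python ints here.
    r, c = jnp.triu_indices(x.shape[0], m=x.shape[1])
    return x[r, c]

print(upper_values(jnp.arange(6.0).reshape(2, 3)))  # [0. 1. 2. 4. 5.]
```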