jax.numpy.tril_indices
- jax.numpy.tril_indices(*args, **kwargs)
Return the indices for the lower-triangle of an (n, m) array.
LAX-backend implementation of numpy.tril_indices(). Original docstring below.
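A minimal sketch of basic usage, assuming a standard JAX and NumPy install. It lists the lower-triangle indices (main diagonal included) of a 3×3 array and checks them against numpy.tril_indices, which this function is documented to mirror.

```python
import numpy as np
import jax.numpy as jnp

# Row and column indices of the lower triangle of a 3x3 array,
# including the main diagonal (k defaults to 0).
rows, cols = jnp.tril_indices(3)
print(rows)  # [0 1 1 2 2 2]
print(cols)  # [0 0 1 0 1 2]

# The LAX-backend version should agree with NumPy's result.
np_rows, np_cols = np.tril_indices(3)
assert np.array_equal(np.asarray(rows), np_rows)
assert np.array_equal(np.asarray(cols), np_cols)
```

Each (rows[i], cols[i]) pair names one element of the triangle, listed in row-major order.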
- Parameters
n (int) – The row dimension of the arrays for which the returned indices will be valid.
k (int, optional) – Diagonal offset (see tril for details). The default k = 0 includes the main diagonal; k < 0 stops below it and k > 0 extends above it.
m (int, optional) – The column dimension of the arrays for which the returned indices will be valid. By default m is taken equal to n. New in version 1.9.0.
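The sketch below illustrates the two optional parameters, assuming they are passed by keyword as k and m (matching the NumPy signature). A negative k drops the main diagonal, and m describes a non-square (n, m) array.

```python
import jax.numpy as jnp

# k=-1: strictly below the main diagonal of a 3x3 array.
below_rows, below_cols = jnp.tril_indices(3, k=-1)
# below_rows -> [1 2 2], below_cols -> [0 0 1]

# m=4: lower triangle (main diagonal included) of a 2x4 array.
rect_rows, rect_cols = jnp.tril_indices(2, k=0, m=4)
# rect_rows -> [0 1 1], rect_cols -> [0 0 1]
```

Because the length of the returned arrays depends on n, k, and m, JAX needs these as concrete Python integers; passing traced values inside a jit-compiled function is expected to raise an error.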
- Returns
inds – The indices for the triangle. The returned tuple contains two arrays, each with the indices along one dimension of the array.
- Return type
tuple of arrays
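Because the return value is a tuple of index arrays, it can be used directly for advanced indexing. The sketch below (variable names are illustrative) gathers the lower-triangle entries of a 3×3 array and then zeroes them. JAX arrays are immutable, so the update uses the .at[...].set(...) syntax rather than in-place assignment.

```python
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)
inds = jnp.tril_indices(3)  # tuple: (row_indices, col_indices)

# Gather the lower-triangle entries in row-major order.
lower_vals = x[inds]  # [0 3 4 6 7 8]

# Functional update: returns a new array with the lower triangle zeroed,
# leaving only the strictly upper triangle of x.
upper_only = x.at[inds].set(0)
# [[0 1 2]
#  [0 0 5]
#  [0 0 0]]
```

The result of the update matches jnp.triu(x, k=1), which keeps only the elements strictly above the main diagonal.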