jax.numpy.tril
- jax.numpy.tril(m, k=0)
Lower triangle of an array.
LAX-backend implementation of numpy.tril(). Original docstring below.
Return a copy of an array with elements above the k-th diagonal zeroed. For arrays with ndim exceeding 2, tril will apply to the final two axes.
- Parameters:
m (array_like, shape (..., M, N)) – Input array.
k (int, optional) – Diagonal above which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it and k > 0 is above.
- Returns:
tril – Lower triangle of m, with the same shape and data type as m.
- Return type:
ndarray, shape (…, M, N)
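The effect of the k parameter is easiest to see on a small square matrix. The following is a minimal sketch, assuming a standard `jax` and `numpy` install; the variable names are illustrative only. It also checks that the results agree with `numpy.tril()` for every diagonal offset of a 3x3 input.

```python
import numpy as np
import jax.numpy as jnp

# 3x3 input: [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
a = jnp.arange(1, 10).reshape(3, 3)

main = jnp.tril(a)         # k=0: keep the main diagonal and everything below it
below = jnp.tril(a, k=-1)  # k<0: the main diagonal is zeroed as well
above = jnp.tril(a, k=1)   # k>0: also keep the first diagonal above the main one

# main  == [[1, 0, 0], [4, 5, 0], [7, 8, 9]]
# below == [[0, 0, 0], [4, 0, 0], [7, 8, 0]]
# above == [[1, 2, 0], [4, 5, 6], [7, 8, 9]]

# Cross-check against NumPy for all offsets that matter on a 3x3 matrix.
for k in range(-3, 4):
    np.testing.assert_array_equal(np.asarray(jnp.tril(a, k=k)), np.tril(np.asarray(a), k=k))
```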
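For inputs with ndim greater than 2, each (M, N) matrix along the leading axes is treated independently, and non-square matrices are allowed. The sketch below illustrates this on a stack of two 3x4 matrices and compares the result with an explicit mask built from `jnp.tri`, which is 1 on and below the k-th diagonal and 0 above it. The mask formulation is offered only as an equivalent reference computation, not as a description of how JAX implements `tril` internally.

```python
import jax.numpy as jnp

# Two stacked 3x4 matrices: shape (..., M, N) with M=3, N=4.
x = jnp.arange(24).reshape(2, 3, 4)

lower = jnp.tril(x, k=0)

# Each matrix in the stack is processed on its own.
for i in range(x.shape[0]):
    assert jnp.array_equal(lower[i], jnp.tril(x[i], k=0))

# Equivalent masking: a (3, 4) boolean mask broadcasts over the batch axis.
mask = jnp.tri(3, 4, k=0, dtype=bool)
manual = jnp.where(mask, x, 0)

assert lower.shape == x.shape
assert jnp.array_equal(lower, manual)
```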