jax.numpy.tril
jax.numpy.tril(m, k=0)
Lower triangle of an array.
LAX-backend implementation of numpy.tril(). Original docstring below.
Return a copy of an array with elements above the k-th diagonal zeroed.
- Parameters
m (array_like, shape (M, N)) – Input array.
k (int, optional) – Diagonal above which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it, and k > 0 is above it.
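The effect of k is easiest to see side by side. The following is a minimal sketch, assuming JAX is installed and imported as jnp; the variable names (m, main, below, above) are illustrative only and do not come from the original docstring.

```python
import jax.numpy as jnp

# A 4x3 input, following the (M, N) shape convention of the parameter list.
m = jnp.arange(1, 13).reshape(4, 3)

main = jnp.tril(m)         # k=0: keep the main diagonal and everything below it
below = jnp.tril(m, k=-1)  # k<0: the main diagonal is zeroed as well
above = jnp.tril(m, k=1)   # k>0: the first diagonal above the main one is kept too

print(main)
print(below)
print(above)
```

An element at row i, column j survives when j - i <= k, so raising k keeps more of the array and lowering it keeps less.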
- Returns
tril – Lower triangle of m, of same shape and data-type as m.
- Return type
ndarray, shape (M, N)
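The shape and dtype guarantee can be checked directly. This is a small sketch, assuming a float32 input built with jnp.ones; the names m and out are hypothetical. Note that the return type above is inherited from the NumPy docstring, and in JAX the result is expected to be a JAX array rather than a NumPy ndarray.

```python
import jax.numpy as jnp

# Non-square float32 input to show that shape and dtype are carried over.
m = jnp.ones((2, 5), dtype=jnp.float32)
out = jnp.tril(m)

print(out.shape == m.shape)  # shape (M, N) is unchanged
print(out.dtype == m.dtype)  # dtype is unchanged
print(out)
```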
See also
triu(): same thing, only for the upper triangle
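The relationship between the two functions can be illustrated with a short sketch, assuming JAX is imported as jnp. The names lower, upper, and recombined are illustrative. Because tril(m, k) keeps entries on or below the k-th diagonal and triu(m, k + 1) keeps entries strictly above it, the two parts should add back up to the original array.

```python
import jax.numpy as jnp

m = jnp.arange(1, 13).reshape(4, 3)
k = 0

lower = jnp.tril(m, k)      # entries on or below the k-th diagonal
upper = jnp.triu(m, k + 1)  # entries strictly above the k-th diagonal

# The two parts do not overlap, so their sum reconstructs m.
recombined = lower + upper
print(bool(jnp.array_equal(recombined, m)))
```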
Examples
>>> np.tril([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1)
array([[ 0,  0,  0],
       [ 4,  0,  0],
       [ 7,  8,  0],
       [10, 11, 12]])