jax.numpy.tri
jax.numpy.tri(N, M=None, k=0, dtype=None)
An array with ones at and below the given diagonal and zeros elsewhere.
LAX-backend implementation of tri(). Original docstring below.
Parameters
N (int) – Number of rows in the array.
M (int, optional) – Number of columns in the array. By default, M is taken equal to N.
k (int, optional) – The sub-diagonal at and below which the array is filled. k = 0 is the main diagonal, while k < 0 is below it, and k > 0 is above. The default is 0.
dtype (dtype, optional) – Data type of the returned array. The default is float.
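The short sketch below, assuming JAX is installed and jax.numpy is imported as jnp, shows how each parameter changes the result: omitting M gives a square array, k moves the boundary diagonal, and dtype controls the element type. The variable names are illustrative only.

```python
import jax.numpy as jnp

# M defaults to N, so this is a 3x3 lower-triangular array of ones
# using the default floating-point dtype.
square = jnp.tri(3)

# k=1 moves the boundary up one diagonal: the first superdiagonal is filled too.
shifted_up = jnp.tri(3, 4, k=1, dtype=jnp.int32)

# k=-1 excludes the main diagonal, leaving only the strictly lower triangle.
strict_lower = jnp.tri(3, k=-1, dtype=jnp.int32)

print(square.shape, square.dtype)
print(shifted_up)
print(strict_lower)
```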
Returns
tri – Array with its lower triangle filled with ones and zeros elsewhere; in other words, T[i,j] == 1 for j <= i + k, 0 otherwise.
Return type
ndarray of shape (N, M)
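The defining rule T[i,j] == 1 for j <= i + k can be reproduced with a broadcast comparison of row and column indices. The helper tri_reference below is hypothetical, written only to illustrate the rule; it is not claimed to be how jax.numpy implements tri, and it is checked against jnp.tri on a few shapes.

```python
import jax.numpy as jnp

def tri_reference(N, M=None, k=0, dtype=jnp.float32):
    """Build the tri array directly from T[i, j] == 1 for j <= i + k.

    Hypothetical helper for illustration only; not part of jax.numpy.
    """
    M = N if M is None else M
    i = jnp.arange(N)[:, None]  # row indices, shape (N, 1)
    j = jnp.arange(M)[None, :]  # column indices, shape (1, M)
    # Broadcasting the comparison yields an (N, M) boolean array,
    # which is then cast to the requested dtype.
    return (j <= i + k).astype(dtype)

# Compare values against jnp.tri for a few (N, M, k) combinations.
for args in [(3, 5, 2), (3, 5, -1), (4, None, 0)]:
    ok = bool(jnp.array_equal(tri_reference(*args), jnp.tri(*args)))
    print(args, ok)
```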
Examples
>>> np.tri(3, 5, 2, dtype=int)
array([[1, 1, 1, 0, 0],
       [1, 1, 1, 1, 0],
       [1, 1, 1, 1, 1]])

>>> np.tri(3, 5, -1)
array([[0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0.]])
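Because the shape of the result is determined by N and M, these must be concrete Python integers when jnp.tri is called inside jax.jit. One way to satisfy this, sketched below as an assumption rather than a prescribed pattern, is to read N and M from the shape of an input array (array shapes are static under jit) and mark k as a static argument. The helper keep_lower is hypothetical; jnp.tril already provides this operation, and the sketch is compared against it.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="k")
def keep_lower(x, k=0):
    # x.shape is known at trace time, so n and m are plain Python ints here.
    n, m = x.shape
    # Ones at and below the k-th diagonal, zeros elsewhere, in x's dtype.
    mask = jnp.tri(n, m, k, dtype=x.dtype)
    return x * mask

x = jnp.arange(1.0, 10.0).reshape(3, 3)
print(keep_lower(x))        # lower triangle, including the main diagonal
print(keep_lower(x, k=-1))  # strictly lower triangle
```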