jax.numpy.tri(N, M=None, k=0, dtype=None)

An array with ones at and below the given diagonal and zeros elsewhere.

LAX-backend implementation of numpy.tri().
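Because this is a LAX-backend reimplementation of `numpy.tri()`, its values should agree with NumPy's for the same arguments. The sketch below checks that for a few argument combinations. It assumes both `numpy` and `jax` are installed, and it compares values only, since the default dtypes can differ (NumPy defaults to float64, while JAX without 64-bit mode typically produces float32).

```python
import numpy as np
import jax.numpy as jnp

# Compare jnp.tri against NumPy's reference implementation.
for N, M, k in [(3, None, 0), (3, 5, 1), (4, 2, -1)]:
    expected = np.tri(N, M, k)
    actual = jnp.tri(N, M, k)
    # Compare values only; dtypes may differ between the two libraries.
    assert np.array_equal(np.asarray(actual), expected)

print("jnp.tri matches np.tri for the tested arguments")
```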

Original docstring below.

Parameters:

  • N (int) – Number of rows in the array.

  • M (int, optional) – Number of columns in the array. By default, M is taken equal to N.

  • k (int, optional) – The sub-diagonal at and below which the array is filled. k = 0 is the main diagonal, while k < 0 is below it, and k > 0 is above. The default is 0.

  • dtype (dtype, optional) – Data type of the returned array. The default is float.
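A short sketch of how `M`, `k`, and `dtype` change the result. The expected values in the comments follow directly from the rule that an entry is 1 exactly when its column index is at most its row index plus `k`.

```python
import jax.numpy as jnp

# Default: square 3x3 with ones on and below the main diagonal.
a = jnp.tri(3)
# a == [[1, 0, 0], [1, 1, 0], [1, 1, 1]]

# M=4 gives 3 rows and 4 columns; k=1 also fills the first superdiagonal.
b = jnp.tri(3, 4, k=1)
# b == [[1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]

# k=-1 excludes the main diagonal; dtype requests an integer result.
c = jnp.tri(3, k=-1, dtype=jnp.int32)
# c == [[0, 0, 0], [1, 0, 0], [1, 1, 0]]
```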


tri – Array with its lower triangle filled with ones and zero elsewhere; in other words T[i,j] == 1 for j <= i + k, 0 otherwise.
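The defining rule T[i,j] == 1 for j <= i + k can be written directly as a broadcast comparison of row and column indices. The helper `tri_from_indices` below is hypothetical, written only to illustrate the rule, and is not part of JAX; it casts to float32 on the assumption that this matches the usual default output of `jnp.tri`.

```python
import jax.numpy as jnp

def tri_from_indices(N, M=None, k=0):
    """Hypothetical helper: build T[i, j] = 1 if j <= i + k, else 0."""
    M = N if M is None else M
    i = jnp.arange(N)[:, None]  # row indices, shape (N, 1)
    j = jnp.arange(M)[None, :]  # column indices, shape (1, M)
    # Broadcasting the comparison yields an (N, M) boolean array.
    return (j <= i + k).astype(jnp.float32)

# The index-based construction agrees with jnp.tri for several shapes and offsets.
for N, M, k in [(4, None, 0), (3, 6, 2), (5, 3, -2)]:
    assert jnp.array_equal(tri_from_indices(N, M, k), jnp.tri(N, M, k))
```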

Return type:

ndarray of shape (N, M)
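Since the result has shape (N, M), it can serve as a multiplicative mask for any array of the same shape. The sketch below uses it to zero out entries above the diagonal and checks that this matches `jnp.tril`; the input array `x` is arbitrary example data.

```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)        # arbitrary 3x4 example input
N, M = x.shape
mask = jnp.tri(N, M, k=0, dtype=x.dtype)  # shape (N, M), same dtype as x
lower = mask * x                          # keep entries on/below the diagonal

assert mask.shape == (N, M)
# Masking with tri gives the same result as taking the lower triangle directly.
assert jnp.array_equal(lower, jnp.tril(x))
```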