jax.numpy.triu

jax.numpy.triu(m, k=0)

Upper triangle of an array.

LAX-backend implementation of numpy.triu().

Original docstring below.

Return a copy of an array with the elements below the k-th diagonal zeroed. For arrays with ndim exceeding 2, triu will apply to the final two axes.
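A minimal illustration, assuming a standard JAX installation. The arrays `a` and `batch` are arbitrary example inputs. The example shows the default `k=0` behavior on a 3x3 matrix, and how a stacked 3-D input is masked matrix by matrix over its final two axes.

```python
import jax.numpy as jnp

a = jnp.arange(1, 10).reshape(3, 3)
# a:
# [[1 2 3]
#  [4 5 6]
#  [7 8 9]]

# k=0 (default): keep the main diagonal and everything above it.
upper = jnp.triu(a)
print(upper)
# [[1 2 3]
#  [0 5 6]
#  [0 0 9]]

# For ndim > 2, triu acts on the last two axes, so each 3x3 matrix
# in the stack is masked independently; the leading axis is untouched.
batch = jnp.stack([a, 10 * a])      # shape (2, 3, 3)
batched_upper = jnp.triu(batch)
print(batched_upper.shape)          # (2, 3, 3)
```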

Please refer to the documentation for tril for further details.
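Because `triu` is the counterpart of `tril`, one way to see the relationship is that `triu(m, k)` and `tril(m, k - 1)` select disjoint sets of entries that together cover the whole array. The sketch below checks this on an arbitrary 4x4 example with `k = 1`. It is an illustration of the two functions' semantics, not a statement taken from the original docstring.

```python
import jax.numpy as jnp

m = jnp.arange(16.0).reshape(4, 4)
k = 1

upper = jnp.triu(m, k=k)       # entries on and above the k-th diagonal
lower = jnp.tril(m, k=k - 1)   # entries strictly below the k-th diagonal

# The two masks do not overlap and together cover every entry,
# so adding the results reconstructs the original array.
reconstructed = upper + lower
print(bool(jnp.array_equal(reconstructed, m)))  # True
```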

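Parameters:
  • m (jax.typing.ArrayLike): input array.

  • k (int): diagonal below which elements are zeroed (default 0). k = 0 is the main diagonal, k < 0 is below it, and k > 0 is above it.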
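To make the role of `k` concrete: entry `(i, j)` lies on diagonal `j - i`, and `triu` keeps it only when `j - i >= k`. `triu_sketch` below is a hypothetical helper written for illustration, not JAX's actual implementation. It builds that mask with broadcasting and `jnp.where`, then compares its output against `jnp.triu` for a negative, zero, and positive `k` on a non-square input.

```python
import jax.numpy as jnp


def triu_sketch(m, k=0):
    """Hypothetical re-implementation for illustration only; not JAX's source."""
    m = jnp.asarray(m)
    rows, cols = m.shape[-2], m.shape[-1]
    # Row and column indices, shaped so that (j - i) broadcasts to (rows, cols).
    i = jnp.arange(rows)[:, None]
    j = jnp.arange(cols)[None, :]
    # Keep entries on or above the k-th diagonal; the 2-D mask
    # broadcasts over any leading batch axes of m.
    keep = (j - i) >= k
    return jnp.where(keep, m, jnp.zeros_like(m))


x = jnp.ones((3, 4), dtype=jnp.int32)
for k in (-1, 0, 1):
    # k < 0 also keeps diagonals below the main one;
    # k > 0 zeroes the main diagonal as well.
    assert jnp.array_equal(triu_sketch(x, k), jnp.triu(x, k))
```

Writing the mask as a comparison on `j - i` makes it clear why the same rule works for non-square inputs: only the offset between column and row index matters.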

Return type:

Array