jax.numpy.diag

jax.numpy.diag(v, k=0)

Extract a diagonal or construct a diagonal array.
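The two modes of the function can be seen side by side in the minimal sketch below. It assumes only that JAX is installed; the variable names `m` and `v` are illustrative. A 2-D input has its main diagonal extracted, and a 1-D input is placed on the diagonal of a new square array.

```python
import jax.numpy as jnp

# 2-D input: jnp.diag extracts the main diagonal.
m = jnp.arange(9).reshape(3, 3)
print(jnp.diag(m))  # [0 4 8]

# 1-D input: jnp.diag builds a square 2-D array with v on the main diagonal.
v = jnp.array([1, 2, 3])
print(jnp.diag(v))
# [[1 0 0]
#  [0 2 0]
#  [0 0 3]]
```

Applying the function twice to a 1-D array returns the original vector, because the second call extracts the diagonal that the first call built.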

LAX-backend implementation of numpy.diag().

Unlike NumPy, the JAX version of this function may return a copy rather than a view of the input. Because JAX arrays are immutable, the difference cannot be observed through in-place writes.
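Because JAX arrays are immutable, "writing" to an extracted diagonal looks different from NumPy. The sketch below is illustrative only (names are hypothetical). It shows that item assignment raises a `TypeError`, and that the functional `.at[...].set(...)` update returns a new array while leaving both the diagonal and the source matrix unchanged.

```python
import jax.numpy as jnp

m = jnp.arange(9).reshape(3, 3)
d = jnp.diag(m)  # holds [0, 4, 8]

# In-place item assignment is not supported on JAX arrays.
try:
    d[0] = 100
except TypeError:
    print("in-place assignment is not supported")

# The functional update returns a new array; d and m are untouched.
d2 = d.at[0].set(100)
print(d2.tolist())  # [100, 4, 8]
print(d.tolist())   # [0, 4, 8]
print(int(m[0, 0])) # 0
```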

Original docstring below.

If you use this function to extract a diagonal and wish to write to the resulting array, see the more detailed documentation for numpy.diagonal; whether it returns a copy or a view depends on which version of NumPy you are using.

Parameters:
  • v (array_like) – If v is a 2-D array, return a copy of its k-th diagonal. If v is a 1-D array, return a 2-D array with v on the k-th diagonal.

  • k (int, optional) – Diagonal in question. The default is 0. Use k>0 for diagonals above the main diagonal, and k<0 for diagonals below the main diagonal.
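The effect of the offset `k` in both modes is shown below in a short sketch, assuming the same 3×3 example matrix as before (illustrative values only). Positive `k` selects diagonals above the main one and negative `k` selects diagonals below it. When a 1-D input is used, the output grows so that the shifted vector fits.

```python
import jax.numpy as jnp

m = jnp.arange(9).reshape(3, 3)
# [[0 1 2]
#  [3 4 5]
#  [6 7 8]]

print(jnp.diag(m, k=1))   # first diagonal above the main one: [1 5]
print(jnp.diag(m, k=-1))  # first diagonal below the main one: [3 7]

# With a 1-D input, k shifts where v is placed; the output becomes
# (len(v) + |k|) x (len(v) + |k|) so the whole vector fits.
v = jnp.array([1, 2])
print(jnp.diag(v, k=1))
# [[0 1 0]
#  [0 0 2]
#  [0 0 0]]
```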

Returns:

out – The extracted diagonal or constructed diagonal array.

Return type:

ndarray
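The shape of `out` depends on the input, as the sketch below illustrates (names and values are illustrative). For a 2-D input of shape `(rows, cols)` and `k = 0`, the extracted diagonal has `min(rows, cols)` entries. For a 1-D input of length `n`, the constructed array has shape `(n + |k|, n + |k|)`.

```python
import jax.numpy as jnp

# Non-square 2-D input: the main diagonal has min(rows, cols) entries.
rect = jnp.arange(8).reshape(2, 4)
print(jnp.diag(rect).shape)  # (2,)

# 1-D input of length 3 with k=-2 gives a 5x5 array; extracting the same
# diagonal again recovers the original vector.
v = jnp.array([1.0, 2.0, 3.0])
built = jnp.diag(v, k=-2)
print(built.shape)                                 # (5, 5)
print(bool(jnp.array_equal(jnp.diag(built, k=-2), v)))  # True
```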