jax.numpy.diag

jax.numpy.diag(v, k=0)

Extract a diagonal or construct a diagonal array.
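The two modes can be illustrated with a minimal sketch, assuming only `jax.numpy` imported as `jnp` and the default `k=0`: a 2-D input yields its main diagonal, and a 1-D input yields a square matrix with that vector on the diagonal.

```python
import jax.numpy as jnp

# 2-D input: jnp.diag extracts the main diagonal (k=0 by default).
x = jnp.arange(9).reshape(3, 3)
# x = [[0, 1, 2],
#      [3, 4, 5],
#      [6, 7, 8]]
d = jnp.diag(x)
print(d)  # [0 4 8]

# 1-D input: jnp.diag builds a square 2-D array with v on the main diagonal
# and zeros everywhere else.
v = jnp.array([1, 2, 3])
m = jnp.diag(v)
print(m)
# [[1 0 0]
#  [0 2 0]
#  [0 0 3]]
```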

LAX-backend implementation of numpy.diag().

Original docstring below.

If you use this function to extract a diagonal and wish to write to the resulting array, see the more detailed documentation for numpy.diagonal; whether it returns a copy or a view depends on which version of NumPy you are using.
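The copy-versus-view note above comes from NumPy. JAX arrays are immutable, so the array returned by `jnp.diag` cannot be written to in place, and item assignment raises a `TypeError`. The sketch below assumes the goal is to replace the main diagonal of a square matrix and uses JAX's functional `.at[...].set(...)` update instead; `set_diagonal` is a hypothetical helper introduced only for illustration.

```python
import jax.numpy as jnp

def set_diagonal(x, values):
    """Hypothetical helper: return a copy of the square matrix `x` with its
    main diagonal replaced by `values`. `x` itself is left unchanged."""
    n = x.shape[0]
    idx = jnp.arange(n)
    # .at[rows, cols].set(...) is the functional counterpart of x[i, i] = v.
    return x.at[idx, idx].set(values)

x = jnp.zeros((3, 3))
d = jnp.diag(x)

# JAX arrays are immutable: item assignment on the extracted diagonal fails.
try:
    d[0] = 1.0
except TypeError as err:
    print("TypeError:", err)

y = set_diagonal(x, jnp.array([1.0, 2.0, 3.0]))
print(jnp.diag(y))  # [1. 2. 3.]
print(jnp.diag(x))  # [0. 0. 0.]  (the original is unchanged)
```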

Parameters
  • v (array_like) – If v is a 2-D array, return a copy of its k-th diagonal. If v is a 1-D array, return a 2-D array with v on the k-th diagonal.

  • k (int, optional) – Diagonal in question. The default is 0. Use k>0 for diagonals above the main diagonal, and k<0 for diagonals below the main diagonal.
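The effect of `k` on both modes can be seen in a short sketch, assuming small integer inputs chosen only for illustration. When constructing, the output is square with side `len(v) + abs(k)`; when extracting, off-diagonals are shorter than the main diagonal.

```python
import jax.numpy as jnp

v = jnp.array([1, 2])

# Constructing with k != 0: v is placed on the k-th diagonal of a
# (len(v) + abs(k)) x (len(v) + abs(k)) array.
above = jnp.diag(v, k=1)
# [[0 1 0]
#  [0 0 2]
#  [0 0 0]]
below = jnp.diag(v, k=-1)
# [[0 0 0]
#  [1 0 0]
#  [0 2 0]]

# Extracting with k != 0 from a 2-D input returns the requested off-diagonal.
x = jnp.arange(9).reshape(3, 3)
print(jnp.diag(x, k=1))   # [1 5]
print(jnp.diag(x, k=-1))  # [3 7]
```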

Returns

out – The extracted diagonal or constructed diagonal array.

Return type

ndarray
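Because the extract and construct modes are complementary, applying `jnp.diag` twice to a square matrix keeps its main diagonal and zeroes every other entry. The sketch below also calls the function under `jax.jit`; it assumes `k` is fixed as a Python constant inside a closure rather than passed as a traced argument, since the shape of the output depends on `k`.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(1, 10).reshape(3, 3)
# x = [[1, 2, 3],
#      [4, 5, 6],
#      [7, 8, 9]]

# Extract the main diagonal, then rebuild a matrix from it:
# the off-diagonal entries of x become zeros.
only_diag = jnp.diag(jnp.diag(x))
# [[1 0 0]
#  [0 5 0]
#  [0 0 9]]

# Under jit, keep k a Python constant (here via a closure): the output
# shape depends on k, and traced code needs static shapes.
superdiag = jax.jit(lambda m: jnp.diag(m, k=1))
print(superdiag(x))  # [2 6]
```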