jax.numpy.diag
- jax.numpy.diag(v, k=0)
Extract a diagonal or construct a diagonal array.
LAX-backend implementation of numpy.diag(). The JAX version of this function may in some cases return a copy rather than a view of the input.
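JAX arrays are immutable, so the result of `jnp.diag` cannot be modified in place, whether it is a copy or not. The following is a minimal sketch, assuming a standard JAX install. It shows that item assignment raises a `TypeError` and that a functional update via `.at[...].set(...)` returns a new array, leaving both the extracted diagonal and the input unchanged.

```python
import jax.numpy as jnp

x = jnp.arange(9).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

# Extract the main diagonal; JAX may return a copy rather than a view.
d = jnp.diag(x)  # [0, 4, 8]

# In-place item assignment is not supported on JAX arrays.
try:
    d[0] = 100
except TypeError:
    pass  # JAX arrays are immutable

# A functional update returns a new array; d and x are left untouched.
d_new = d.at[0].set(100)  # [100, 4, 8]
```

Because `d` is never written to, the question of whether it aliases `x` has no observable effect on the values the program computes.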
Original docstring below.
See the more detailed documentation for numpy.diagonal if you use this function to extract a diagonal and wish to write to the resulting array; whether it returns a copy or a view depends on what version of numpy you are using.
- Parameters:
v (array_like) – If v is a 2-D array, return a copy of its k-th diagonal. If v is a 1-D array, return a 2-D array with v on the k-th diagonal.
k (int, optional) – Diagonal in question. The default is 0. Use k>0 for diagonals above the main diagonal, and k<0 for diagonals below the main diagonal.
- Returns:
out – The extracted diagonal or constructed diagonal array.
- Return type:
ndarray
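The following is a minimal sketch, assuming a standard JAX install, of the two modes described above. A 2-D input yields its k-th diagonal. A 1-D input of length n yields a square array of size n + |k| with v placed on the k-th diagonal.

```python
import jax.numpy as jnp

# 2-D input: extract diagonals.
x = jnp.arange(9).reshape(3, 3)
# [[0, 1, 2],
#  [3, 4, 5],
#  [6, 7, 8]]
main = jnp.diag(x)          # main diagonal: [0, 4, 8]
upper = jnp.diag(x, k=1)    # above the main diagonal: [1, 5]
lower = jnp.diag(x, k=-1)   # below the main diagonal: [3, 7]

# 1-D input: construct a 2-D array with v on the k-th diagonal.
v = jnp.array([1, 2, 3])
m0 = jnp.diag(v)            # 3x3, v on the main diagonal
m1 = jnp.diag(v, k=1)       # 4x4, v one position above the main diagonal
```

Applying the function twice, as in `jnp.diag(jnp.diag(x))`, keeps only the main diagonal of `x` and zeroes every other entry.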