jax.numpy.diagflat
- jax.numpy.diagflat(v, k=0)
Create a two-dimensional array with the flattened input as a diagonal.
LAX-backend implementation of numpy.diagflat(). This differs from np.diagflat for some scalar values of v: JAX always returns a two-dimensional array, whereas NumPy may return a scalar depending on the type of v.
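The scalar case can be checked with a minimal sketch like the one below. It assumes a working JAX installation; the variable name `out` is illustrative only, and the NumPy side of the comparison is deliberately not shown.

```python
import jax.numpy as jnp

# A Python scalar is flattened to a length-1 vector before being
# placed on the diagonal, so the result is still two-dimensional.
out = jnp.diagflat(3.0)

print(out.ndim)   # number of dimensions of the result
print(out.shape)  # shape of the result
```

For a scalar input the result is a 1-by-1 array, not a zero-dimensional value.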
Original docstring below.
- Parameters:
v (array_like) – Input data, which is flattened and set as the k-th diagonal of the output.
k (int, optional) – Diagonal to set. The default, 0, is the “main” diagonal; a positive (negative) k selects the k-th diagonal above (below) the main diagonal.
- Returns:
out – The 2-D output array.
- Return type:
ndarray
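As an illustration of the v and k parameters, the sketch below flattens a 2-by-2 input and places it on the main diagonal, the first diagonal above it, and the first diagonal below it. The names `v`, `main`, `upper`, and `lower` are chosen for this example only.

```python
import jax.numpy as jnp

# A 2x2 input is flattened row by row to [1, 2, 3, 4].
v = jnp.array([[1, 2], [3, 4]])

main = jnp.diagflat(v)         # values on the main diagonal
upper = jnp.diagflat(v, k=1)   # values one diagonal above the main
lower = jnp.diagflat(v, k=-1)  # values one diagonal below the main

# The output side length is the flattened length plus abs(k).
print(main.shape)
print(upper.shape)
print(lower.shape)
```

With a nonzero k the output grows by abs(k) rows and columns, so that all flattened values fit on the shifted diagonal.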