jax.numpy.fill_diagonal
- jax.numpy.fill_diagonal(a, val, wrap=False, *, inplace=True)
Fill the main diagonal of the given array of any dimensionality.
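A minimal sketch of filling the diagonal of a 2-D JAX array. It passes inplace=False, which the JAX version requires (the inplace parameter is discussed below), and it uses the returned array, because the input array is never modified.

```python
import jax.numpy as jnp

# Start from a 3x3 matrix of zeros.
a = jnp.zeros((3, 3))

# JAX requires inplace=False; the filled matrix comes back as a new array.
b = jnp.fill_diagonal(a, 5.0, inplace=False)

print(b)
```

Here b equals 5.0 * jnp.eye(3), while a still holds only zeros.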
LAX-backend implementation of numpy.fill_diagonal().
The semantics of numpy.fill_diagonal() are to modify arrays in-place, which JAX cannot do because JAX arrays are immutable. Thus jax.numpy.fill_diagonal() adds the inplace parameter, which the user must set to False as a reminder of this API difference; the function then returns a new array with the diagonal filled and leaves a unchanged.
Original docstring below.
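The sketch below illustrates this API difference. With the default inplace=True the call raises instead of mutating; the exception type assumed here is NotImplementedError, which is what recent JAX releases raise. With inplace=False the input array is left unchanged and the result is a separate array.

```python
import jax.numpy as jnp

a = jnp.zeros((2, 2))

# Leaving inplace at its default (True) is rejected, because JAX arrays
# cannot be mutated. The exception type is an assumption (see lead-in).
try:
    jnp.fill_diagonal(a, 1.0)
    default_raised = False
except NotImplementedError as err:
    default_raised = True
    print("inplace=True rejected:", err)

# With inplace=False a new array is returned and `a` keeps its zeros.
b = jnp.fill_diagonal(a, 1.0, inplace=False)
print(a)
print(b)
```

The same result can also be written with JAX's general functional update syntax, a.at[jnp.diag_indices(2)].set(1.0).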
For an array a with a.ndim >= 2, the diagonal is the set of locations a[i, ..., i], that is, the positions whose indices are all identical. This function modifies the input array in-place; it does not return a value.
- Parameters:
a (array, at least 2-D) – Array whose diagonal is to be filled; it is modified in-place.
val (scalar or array_like) – Value(s) to write on the diagonal. If val is a scalar, that value is written along the diagonal. If val is array-like, the flattened val is written along the diagonal, repeating if necessary to fill all diagonal entries.
wrap (bool) – For tall matrices in NumPy versions up to 1.6.2, the diagonal “wrapped” after N columns. This option reproduces that behavior. It affects only tall matrices.
inplace (bool, default=True) – If left at its default value of True, JAX raises an error. This is because the semantics of numpy.fill_diagonal() are to modify the array in-place, which is not possible in JAX because JAX arrays are immutable.
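A sketch of how a, val and wrap interact, under stated assumptions. For a 3-D cube only the entries cube[i, i, i] are set, which matches an explicit update on jnp.diag_indices. An array-like val is flattened and repeated along the diagonal; the JAX result is checked against NumPy. Because the JAX implementation may not support wrap=True, the last part shows NumPy's wrapped result on a tall 5x3 matrix next to fill_diagonal_wrapped, a hypothetical helper (not a JAX API) that reproduces it functionally by setting every (ncols + 1)-th position of the row-major flattened array, mirroring how NumPy implements the option.

```python
import numpy as np
import jax.numpy as jnp

# 1) Diagonal of an n-D array: only positions whose indices are all equal.
cube = jnp.fill_diagonal(jnp.zeros((3, 3, 3), dtype=jnp.int32), 7, inplace=False)
# Equivalent explicit functional update on the same diagonal index arrays.
cube_at = jnp.zeros((3, 3, 3), dtype=jnp.int32).at[jnp.diag_indices(3, ndim=3)].set(7)

# 2) Array-like val: flattened, then repeated to cover the whole diagonal.
pattern = np.array([1, 2])
rep_jax = jnp.fill_diagonal(jnp.zeros((4, 4), dtype=jnp.int32), pattern, inplace=False)
rep_np = np.zeros((4, 4), dtype=np.int32)
np.fill_diagonal(rep_np, pattern)  # NumPy mutates rep_np and returns None

# 3) wrap=True on a tall matrix, using NumPy's semantics as the reference.
tall_np = np.zeros((5, 3), dtype=np.int32)
np.fill_diagonal(tall_np, 4, wrap=True)

def fill_diagonal_wrapped(a, val):
    """Hypothetical helper (not a JAX API): mimic NumPy's wrap=True for a 2-D array."""
    step = a.shape[1] + 1  # distance between diagonal entries in row-major order
    return a.ravel().at[::step].set(val).reshape(a.shape)

tall_jax = fill_diagonal_wrapped(jnp.zeros((5, 3), dtype=jnp.int32), 4)

print(np.diag(rep_np))
print(tall_np)
```

In the wrapped 5x3 case the diagonal fills rows 0 to 2, row 3 stays empty, and the pattern restarts at position (4, 0).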
- Return type: