jax.numpy.fill_diagonal

jax.numpy.fill_diagonal(a, val, wrap=False, *, inplace=True)

Fill the main diagonal of the given array of any dimensionality.

LAX-backend implementation of numpy.fill_diagonal().

The semantics of numpy.fill_diagonal() are to modify arrays in-place, which JAX cannot do because JAX arrays are immutable. Thus jax.numpy.fill_diagonal() adds the inplace parameter, which the user must set to False as a reminder of this API difference; the function then returns a new array with the diagonal filled.
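A minimal sketch of the call pattern described above: the input is left untouched and the filled result comes back as a new array. The exact exception raised when inplace is left at True is not stated here, so the example catches a generic Exception.

```python
import jax.numpy as jnp

x = jnp.zeros((3, 3))

# inplace=False acknowledges that JAX returns a new array
# rather than mutating `x`.
y = jnp.fill_diagonal(x, 5.0, inplace=False)

print(y)
# [[5. 0. 0.]
#  [0. 5. 0.]
#  [0. 0. 5.]]
print(x.sum())  # x is unchanged: 0.0

# Leaving inplace at its default (True) raises an error.
try:
    jnp.fill_diagonal(x, 5.0)
except Exception as err:  # exception type is an implementation detail
    print("error:", type(err).__name__)
```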

The original NumPy docstring follows.

For an array a with a.ndim >= 2, the diagonal is the list of locations whose indices a[i, ..., i] are all identical. In NumPy, this function modifies the input array in-place and does not return a value; the JAX version returns a new array instead.
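To illustrate the a[i, ..., i] definition for more than two dimensions, the sketch below fills a 3-D array and compares the result with an explicit functional update built from jax.numpy.diag_indices and the .at[...].set(...) syntax. This comparison is only an equivalent formulation for illustration, not a claim about how fill_diagonal is implemented internally. It uses a cube (all axes of equal length), since NumPy's definition of the diagonal for n-D arrays assumes equal dimensions.

```python
import jax.numpy as jnp

a = jnp.zeros((2, 2, 2), dtype=jnp.int32)

# For a 3-D array the diagonal is the set of positions a[i, i, i].
b = jnp.fill_diagonal(a, 7, inplace=False)

# The same positions addressed explicitly: one index array [0, 1]
# per axis, then written with a functional (copy-producing) update.
idx = jnp.diag_indices(2, ndim=3)
c = a.at[idx].set(7)

print(b[0, 0, 0], b[1, 1, 1], b.sum())  # 7 7 14
print(bool((b == c).all()))            # True
```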

Parameters:
  • a (array, at least 2-D) – Array whose diagonal is to be filled. NumPy modifies it in-place; JAX leaves it unchanged and returns a modified copy.

  • val (scalar or array_like) – Value(s) to write on the diagonal. If val is scalar, the value is written along the diagonal. If array-like, the flattened val is written along the diagonal, repeating if necessary to fill all diagonal entries.

  • wrap (bool) – For tall matrices in NumPy versions up to 1.6.2, the diagonal “wrapped” after N columns; setting this option to True reproduces that behavior. It affects only tall matrices. JAX does not implement wrap=True; only the default value of False is supported.

  • inplace (bool, default=True) – If left at its default value of True, JAX raises an error, because the semantics of numpy.fill_diagonal() are to modify the array in-place, which is not possible given the immutability of JAX arrays. Pass inplace=False to receive a new array with the diagonal filled.
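The sketch below exercises the val and wrap descriptions together on a tall matrix: with the default wrap=False, only the first min(rows, cols) diagonal entries are written, and an array-valued val is flattened and repeated to cover them. It assumes NumPy is installed alongside JAX so the result can be checked against numpy.fill_diagonal, which mutates its argument instead of returning a value.

```python
import jax.numpy as jnp
import numpy as np

tall = jnp.zeros((5, 3), dtype=jnp.int32)
vals = jnp.array([1, 2], dtype=jnp.int32)

# wrap=False (the default): only min(5, 3) = 3 diagonal entries are set,
# and `vals` is repeated as needed: 1, 2, 1.
out = jnp.fill_diagonal(tall, vals, inplace=False)
print(out)
# [[1 0 0]
#  [0 2 0]
#  [0 0 1]
#  [0 0 0]
#  [0 0 0]]

# NumPy reference: same values, but written into `ref` in place.
ref = np.zeros((5, 3), dtype=np.int32)
np.fill_diagonal(ref, [1, 2])
print(np.array_equal(np.asarray(out), ref))  # True
```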

Return type:

Array