jax.numpy.linalg.pinv

jax.numpy.linalg.pinv(a, rtol=None, hermitian=False, *, rcond=Deprecated)

Compute the (Moore-Penrose) pseudo-inverse of a matrix.

JAX implementation of numpy.linalg.pinv().
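For reference, this is the standard definition (stated here for context, not taken from the JAX source). The Moore-Penrose pseudo-inverse A⁺ of A is the unique matrix that satisfies the four Penrose conditions below, where A^H is the conjugate transpose. Given a reduced singular value decomposition of A, A⁺ is obtained by inverting the nonzero singular values and leaving the zero ones at zero:

```latex
\begin{gathered}
A A^{+} A = A, \qquad A^{+} A A^{+} = A^{+}, \qquad
(A A^{+})^{H} = A A^{+}, \qquad (A^{+} A)^{H} = A^{+} A, \\
A = U \Sigma V^{H} \;\Longrightarrow\; A^{+} = V \Sigma^{+} U^{H}
\end{gathered}
```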

Parameters:
  • a (ArrayLike) – array of shape (..., M, N) containing matrices to pseudo-invert.

  • rtol (ArrayLike | None) – float or array_like of shape a.shape[:-2]. Cutoff for small singular values: singular values smaller than rtol * largest_singular_value are treated as zero. The default is determined by the floating-point precision of the dtype (see Notes).

  • hermitian (bool) – if True, then the input is assumed to be Hermitian, and a more efficient algorithm is used (default: False)

  • rcond (ArrayLike | DeprecatedArg | None) – deprecated; use rtol instead.
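The sketch below illustrates the rtol and hermitian parameters on small hand-picked matrices (the matrices and the rtol value of 0.5 are illustrative choices, not library defaults). A rank-deficient matrix stays well behaved under the default cutoff, a large rtol discards more singular values, and hermitian=True on an invertible symmetric matrix reproduces the ordinary inverse.

```python
import jax.numpy as jnp

# Rank-deficient 3x3 matrix: the third row is the sum of the first two,
# so its exact singular values are roughly 15.66, 0.81 and 0.
a = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0],
               [5.0, 7.0, 9.0]])

# Default cutoff: the third singular value (nonzero only through rounding)
# falls below rtol * largest_singular_value and is treated as zero, so the
# entries of the result stay small instead of blowing up.
p_default = jnp.linalg.pinv(a)

# rtol=0.5 also discards the second singular value (0.81 < 0.5 * 15.66),
# leaving a rank-1 pseudo-inverse built from the largest singular value only.
p_rank1 = jnp.linalg.pinv(a, rtol=0.5)

# hermitian=True on a real symmetric (hence Hermitian) matrix; because this
# matrix is invertible, its pseudo-inverse equals the ordinary inverse.
s = jnp.array([[2.0, 1.0],
               [1.0, 2.0]])
p_herm = jnp.linalg.pinv(s, hermitian=True)
print(bool(jnp.allclose(p_herm, jnp.linalg.inv(s), atol=1e-5)))  # True
```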

Returns:

An array of shape (..., N, M) containing the pseudo-inverse of a.

Return type:

Array
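As a small illustration of the shape contract, the sketch below pseudo-inverts a batch of matrices (the batch itself is an arbitrary example). Each matrix in the leading dimensions is handled independently, the last two axes swap from (M, N) to (N, M), and rtol can be supplied per matrix as an array of shape a.shape[:-2].

```python
import jax.numpy as jnp

# A batch of four 3x2 matrices: shape (4, 3, 2), i.e. (..., M, N) with M=3, N=2.
batch = jnp.arange(24.0).reshape(4, 3, 2)

# Each matrix is pseudo-inverted independently; the output shape is (..., N, M).
batch_pinv = jnp.linalg.pinv(batch)
print(batch_pinv.shape)  # (4, 2, 3)

# rtol given per matrix, as an array of shape a.shape[:-2] == (4,).
per_matrix_rtol = jnp.full((4,), 1e-5)
batch_pinv_rtol = jnp.linalg.pinv(batch, rtol=per_matrix_rtol)
print(batch_pinv_rtol.shape)  # (4, 2, 3)
```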

Notes

jax.numpy.linalg.pinv() differs from numpy.linalg.pinv() in the default value of the singular-value cutoff (rcond, now rtol): in NumPy, the default is 1e-15. In JAX, the default is 10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps.
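To make the default cutoff concrete, here is a simplified sketch of an SVD-based pseudo-inverse that applies the default relative tolerance from the note above. The helper pinv_via_svd is hypothetical and is not JAX's own implementation: it assumes a single real 2-D matrix, casts to float32, and does no batching or complex handling.

```python
import jax.numpy as jnp

def pinv_via_svd(a):
    """Hypothetical helper: pseudo-inverse of one real 2-D matrix via SVD.

    Sketch only (float32, no batching, no complex support); it mirrors the
    default cutoff described in the note, not JAX's internal code.
    """
    a = jnp.asarray(a, dtype=jnp.float32)
    num_rows, num_cols = a.shape
    # Default relative tolerance: 10 * max(num_rows, num_cols) * eps.
    rtol = 10. * max(num_rows, num_cols) * jnp.finfo(a.dtype).eps
    # Reduced SVD: a == u @ diag(s) @ vh, with s sorted in descending order.
    u, s, vh = jnp.linalg.svd(a, full_matrices=False)
    cutoff = rtol * s[0]
    # Reciprocals of singular values above the cutoff; the rest become 1/inf == 0.
    s_inv = 1.0 / jnp.where(s > cutoff, s, jnp.inf)
    # pinv(a) = V @ diag(s_inv) @ U^T; scaling the columns of vh.T applies diag(s_inv).
    return (vh.T * s_inv) @ u.T

a = jnp.array([[1, 2], [3, 4], [5, 6]])
print(bool(jnp.allclose(pinv_via_svd(a), jnp.linalg.pinv(a), atol=1e-4)))  # True
```

Because the cutoff is relative to the largest singular value, singular values that are nonzero only through float32 rounding are zeroed rather than inverted into very large entries.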

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4],
...                [5, 6]])
>>> a_pinv = jnp.linalg.pinv(a)
>>> a_pinv
Array([[-1.333332  , -0.33333257,  0.6666657 ],
       [ 1.0833322 ,  0.33333272, -0.41666582]], dtype=float32)

The pseudo-inverse acts as a (left) multiplicative inverse as long as the input is not rank-deficient. Here a has full column rank, so a_pinv @ a recovers the 2x2 identity:

>>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4)
Array(True, dtype=bool)
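As an additional check, the sketch below verifies the four Moore-Penrose conditions for the same 3x2 matrix and shows that a_pinv is only a left inverse. The tolerance of 1e-3 is an assumption chosen to absorb float32 rounding, not a library-recommended value.

```python
import jax.numpy as jnp

a = jnp.array([[1., 2.],
               [3., 4.],
               [5., 6.]])
a_pinv = jnp.linalg.pinv(a)

# The four Moore-Penrose conditions; for a real matrix the conjugate
# transpose is just the transpose.
conditions = [
    jnp.allclose(a @ a_pinv @ a, a, atol=1e-3),
    jnp.allclose(a_pinv @ a @ a_pinv, a_pinv, atol=1e-3),
    jnp.allclose((a @ a_pinv).T, a @ a_pinv, atol=1e-3),
    jnp.allclose((a_pinv @ a).T, a_pinv @ a, atol=1e-3),
]
print([bool(c) for c in conditions])  # [True, True, True, True]

# a @ a_pinv is the 3x3 orthogonal projector onto the column space of a
# (rank 2), so it cannot equal the 3x3 identity.
print(bool(jnp.allclose(a @ a_pinv, jnp.eye(3), atol=1e-3)))  # False
```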