jax.numpy.roots
- jax.numpy.roots(p, *, strip_zeros=True)
Return the roots of a polynomial with coefficients given in p.
LAX-backend implementation of numpy.roots().
Unlike the NumPy version of this function, the JAX version returns the roots in a complex array regardless of the values of the roots. Additionally, the JAX version of this function adds the strip_zeros argument, which must be set to False for the function to be compatible with JIT and other JAX transformations. With strip_zeros=False, if your coefficients have leading zeros, the roots will be padded with NaN values:
>>> coeffs = jnp.array([0, 1, 2])
>>> # The default behavior matches NumPy and strips leading zeros:
>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)
>>> # With strip_zeros=False, extra roots are set to NaN:
>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)
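As a sketch of the JIT-compatible usage described above, the function below wraps jnp.roots in jax.jit with strip_zeros=False, so the output always has len(p) - 1 entries, and the NaN padding is filtered afterwards outside the traced function. The helper name padded_roots is hypothetical, and filtering with a boolean mask is just one possible post-processing choice; because a boolean mask produces a data-dependent shape, that step has to stay outside jit.

```python
import jax
import jax.numpy as jnp


@jax.jit
def padded_roots(p):
    # Hypothetical helper: strip_zeros=False keeps the output shape fixed at
    # (len(p) - 1,), which is what allows tracing under jax.jit.
    return jnp.roots(p, strip_zeros=False)


coeffs = jnp.array([0.0, 1.0, 2.0])  # leading zero, so effectively x + 2
r = padded_roots(coeffs)             # one real root plus one NaN entry

# Outside of jit, drop the NaN padding to recover a NumPy-style result.
# Boolean-mask indexing has a data-dependent shape, so it cannot be jitted.
finite_roots = r[~jnp.isnan(r)]
print(r.shape, finite_roots)
```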
Original docstring below.
Note
This forms part of the old polynomial API. Since version 1.4, the new polynomial API defined in numpy.polynomial is preferred. A summary of the differences can be found in the transition guide.
The values in the rank-1 array p are coefficients of a polynomial. If the length of p is n+1 then the polynomial is described by:
p[0] * x**n + p[1] * x**(n-1) + ... + p[n-1]*x + p[n]
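To make the coefficient ordering concrete, the sketch below encodes x**2 - 3*x + 2 = (x - 1)(x - 2) with p[0] as the highest-degree coefficient, evaluates the formula above with a small Horner loop, and checks that it vanishes at the computed roots. The helper eval_poly is hypothetical (jnp.polyval does the same job in the library). The example also shows that the result has a complex dtype even though both roots are real, whereas numpy.roots would typically return a real array here.

```python
import jax.numpy as jnp

# p[0] is the coefficient of the highest power: x**2 - 3*x + 2 = (x - 1)(x - 2).
p = jnp.array([1.0, -3.0, 2.0])
r = jnp.roots(p)  # complex dtype even though both roots are real


def eval_poly(p, x):
    # Hypothetical helper: evaluates p[0]*x**n + ... + p[n-1]*x + p[n]
    # with Horner's rule, i.e. the same ordering that jnp.roots expects.
    acc = jnp.zeros_like(x)
    for c in p:
        acc = acc * x + c
    return acc


residual = jnp.abs(eval_poly(p, r))  # should be close to zero at each root
real_parts = jnp.sort(r.real)        # root order is not guaranteed, so sort
print(r.dtype, real_parts, residual.max())
```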
- Parameters:
p (array_like) – Rank-1 array of polynomial coefficients.
strip_zeros (bool, default=True) – If set to True, then leading zeros in the coefficients will be stripped, similar to numpy.roots(). If set to False, leading zeros will not be stripped, and undefined roots will be represented by NaN values in the function output. strip_zeros must be set to False for the function to be compatible with jax.jit() and other JAX transformations.
- Returns:
out – An array containing the roots of the polynomial.
- Return type:
ndarray
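Because strip_zeros=False gives every polynomial of a given length an output of the same shape, jnp.roots can also be batched. The sketch below applies jax.vmap to a stack of degree-2 coefficient rows, one of which has a leading zero. It assumes all rows share the same length and that jax.vmap is among the "other JAX transformations" that the strip_zeros description refers to.

```python
import jax
import jax.numpy as jnp

# Each row is one polynomial; all rows must have the same length to be stacked.
batch = jnp.array([
    [1.0, -3.0, 2.0],  # x**2 - 3*x + 2, roots 1 and 2
    [0.0, 1.0, 2.0],   # leading zero: effectively x + 2, root -2
])

# strip_zeros=False keeps each per-row output at shape (2,), so vmap can
# stack the results; the second row gets one NaN entry as padding.
batched_roots = jax.vmap(lambda p: jnp.roots(p, strip_zeros=False))(batch)

# Count the NaN-padded (undefined) roots per polynomial.
num_undefined = jnp.isnan(batched_roots).sum(axis=1)
print(batched_roots.shape, num_undefined)
```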