jax.numpy.roots(p, *, strip_zeros=True)

Return the roots of a polynomial with coefficients given in p.

LAX-backend implementation of numpy.roots().

Unlike the NumPy version of this function, the JAX version always returns the roots in a complex array, even when every root is real. Additionally, the JAX version adds a strip_zeros argument, which must be set to False for the function to be compatible with JIT and other JAX transformations. With strip_zeros=False, if your coefficients have leading zeros, the roots will be padded with NaN values:

>>> import jax.numpy as jnp
>>> coeffs = jnp.array([0, 1, 2])

# The default behavior matches NumPy and strips leading zeros:
>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)

# With strip_zeros=False, extra roots are set to NaN:
>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)
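To illustrate the JIT case, here is a minimal sketch that wraps jnp.roots in jax.jit with strip_zeros=False; the function name roots_jit is only illustrative. Keeping leading zeros means the output always has len(p) - 1 entries, a shape JAX can know at trace time, whereas the default strip_zeros=True needs the number of leading zeros as a concrete value and is therefore expected to fail under jax.jit. Per the description above, the traced result for these coefficients should have one finite root near -2 and one NaN entry.

```python
import jax
import jax.numpy as jnp


# Illustrative wrapper: strip_zeros=False keeps the output length fixed at
# len(p) - 1, which is what makes the function traceable under jax.jit.
@jax.jit
def roots_jit(p):
    return jnp.roots(p, strip_zeros=False)


coeffs = jnp.array([0.0, 1.0, 2.0])  # 0*x**2 + 1*x + 2 (one leading zero)
r = roots_jit(coeffs)
print(r)
```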

Original docstring below.


This forms part of the old polynomial API. Since NumPy version 1.4, the new polynomial API defined in numpy.polynomial is preferred. A summary of the differences can be found in the transition guide.

The values in the rank-1 array p are coefficients of a polynomial. If the length of p is n+1 then the polynomial is described by:

p[0] * x**n + p[1] * x**(n-1) + ... + p[n-1]*x + p[n]
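As a quick check of this coefficient ordering (highest power first), the sketch below finds the roots of a cubic whose roots are known and evaluates the polynomial at them with jnp.polyval, which uses the same ordering. The particular polynomial is just an example; the residuals should be close to zero, up to float32 rounding.

```python
import jax.numpy as jnp

# p[0] multiplies x**3, p[3] is the constant term:
# x**3 - 6*x**2 + 11*x - 6 = (x - 1)(x - 2)(x - 3)
p = jnp.array([1.0, -6.0, 11.0, -6.0])
r = jnp.roots(p)  # complex array of length len(p) - 1 = 3

# jnp.polyval uses the same highest-power-first convention, so evaluating
# the polynomial at each root should give values near zero.
residuals = jnp.polyval(p, r)
print(r)
print(jnp.abs(residuals))
```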
Parameters:
  • p (array_like) – Rank-1 array of polynomial coefficients.

  • strip_zeros (bool, default=True) – If set to True, then leading zeros in the coefficients will be stripped, similar to numpy.roots(). If set to False, leading zeros will not be stripped, and undefined roots will be represented by NaN values in the function output. strip_zeros must be set to False for the function to be compatible with jax.jit() and other JAX transformations.
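The same fixed-length behavior is what lets strip_zeros=False work with other transformations such as jax.vmap. The sketch below, assuming a small hand-written batch of quadratics (the names batch and batched_roots are illustrative), maps jnp.roots over rows of equal length; a row with a leading zero gets a NaN in place of its missing root, so all results stack into one array.

```python
import jax
import jax.numpy as jnp

# Each row is one quadratic, highest-degree coefficient first.
# Row 0: x**2 - 3x + 2 = (x - 1)(x - 2)
# Row 1: 0*x**2 + x + 2, a leading zero, so effectively degree 1
batch = jnp.array([
    [1.0, -3.0, 2.0],
    [0.0, 1.0, 2.0],
])

# strip_zeros=False gives every row len(p) - 1 = 2 roots, so vmap can
# stack the per-row results into a single (2, 2) complex array.
batched_roots = jax.vmap(lambda p: jnp.roots(p, strip_zeros=False))(batch)
print(batched_roots)
```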


Returns:
  out – An array containing the roots of the polynomial.
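The complex return type noted earlier is easy to see next to NumPy. In this minimal sketch with an example polynomial that has only real roots, numpy.roots returns a real array, while jnp.roots returns a complex one; if you know the roots are real, jnp.real drops the (near-zero) imaginary parts.

```python
import numpy as np
import jax.numpy as jnp

p = [1.0, -3.0, 2.0]  # x**2 - 3x + 2, real roots 1 and 2

np_roots = np.roots(p)               # real dtype, since all roots are real
jnp_roots = jnp.roots(jnp.array(p))  # always a complex dtype in JAX
print(np_roots.dtype, jnp_roots.dtype)

# Keep only the real parts when the roots are known to be real.
real_roots = jnp.real(jnp_roots)
print(real_roots)
```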

Return type: