jax.numpy.roots

jax.numpy.roots(p, *, strip_zeros=True)

Return the roots of a polynomial with coefficients given in p.

LAX-backend implementation of numpy.roots().

If the coefficient array p has length n and does not start with zero, the polynomial has degree n - 1 and therefore n - 1 roots. If the coefficients do have leading zeros, the polynomial they define has a smaller degree, so the number of roots (and thus the output shape) is value-dependent.
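The n - 1 count follows from a common way of computing roots: NumPy documents its roots() as taking the eigenvalues of the polynomial's companion matrix, which is (n - 1) x (n - 1) for n coefficients. The sketch below illustrates that technique under stated assumptions (real coefficients, nonzero leading term). companion_roots is a hypothetical helper, not part of the JAX API, and is not claimed to match JAX's internal code line for line. Depending on the JAX version, jnp.linalg.eigvals (a nonsymmetric eigendecomposition) may only be supported on CPU.

```python
import jax.numpy as jnp


def companion_roots(p):
    """Hypothetical helper: roots of a polynomial via its companion matrix.

    Assumes p is a rank-1 sequence of n >= 2 real coefficients, ordered from
    the highest power down, with p[0] != 0 (no leading zeros).
    """
    p = jnp.asarray(p, dtype=jnp.float32)
    n = p.shape[0]
    # (n-1) x (n-1) matrix with ones on the first subdiagonal ...
    A = jnp.diag(jnp.ones(n - 2, dtype=p.dtype), k=-1)
    # ... and the negated, normalized coefficients in the first row.
    A = A.at[0, :].set(-p[1:] / p[0])
    # The n-1 eigenvalues of A are the n-1 roots of the polynomial.
    return jnp.linalg.eigvals(A)


# x**3 - 6*x**2 + 11*x - 6 = (x - 1)(x - 2)(x - 3): 4 coefficients, 3 roots.
r = companion_roots([1.0, -6.0, 11.0, -6.0])
print(r.shape)  # (3,)
print(jnp.sort(r.real))
```

The first row divides by p[0], so a leading zero would put infinite entries into the matrix; this is one way to see why leading zeros have to be removed before this step rather than passed through.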

The general implementation therefore cannot be transformed with jit, which requires output shapes to be known at trace time. If the coefficients are guaranteed to have no leading zeros, use the keyword argument strip_zeros=False to get a jit-compatible variant:

>>> from functools import partial
>>> import jax
>>> import jax.numpy as jnp
>>> roots_unsafe = jax.jit(partial(jnp.roots, strip_zeros=False))
>>> roots_unsafe(jnp.array([1, 2]))     # ok: leading coefficient is nonzero
DeviceArray([-2.+0.j], dtype=complex64)
>>> roots_unsafe(jnp.array([0, 1, 2]))  # problem: leading zero yields nan
DeviceArray([nan+nanj, nan+nanj], dtype=complex64)
>>> jnp.roots(jnp.array([0, 1, 2]))     # use the non-jit version instead
DeviceArray([-2.+0.j], dtype=complex64)
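When coefficient arrays that may carry leading zeros still need to go through the jitted variant, one option is to strip the zeros eagerly, on concrete host values, so the degree is fixed before tracing. This is a sketch assuming the coefficients are concrete (not traced) and not all zero; roots_trimmed is a hypothetical helper, while np.trim_zeros with trim="f" is NumPy's function for removing zeros from the front only.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# jit-compatible variant: only valid when the leading coefficient is nonzero.
roots_unsafe = jax.jit(partial(jnp.roots, strip_zeros=False))


def roots_trimmed(coeffs):
    """Hypothetical helper: trim leading zeros eagerly, then call the jitted variant.

    Assumes coeffs is a concrete 1-D sequence that is not all zeros.
    """
    # trim="f" removes zeros from the front only; trailing zeros encode
    # roots at x = 0 and must be kept.
    trimmed = np.trim_zeros(np.asarray(coeffs, dtype=np.float32), trim="f")
    return roots_unsafe(jnp.asarray(trimmed))


r = roots_trimmed([0, 0, 1, -3, 2])  # x**2 - 3*x + 2 after trimming
print(r.shape)  # (2,)
```

Because jit specializes on input shape, each distinct trimmed length is compiled once and then reused from the cache.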

Original docstring below.

The values in the rank-1 array p are coefficients of a polynomial. If the length of p is n + 1, then the polynomial is described by:

p[0] * x**n + p[1] * x**(n-1) + ... + p[n-1] * x + p[n]

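The formula above uses the highest-power-first ordering that jnp.polyval also uses, so a quick sanity check is to evaluate the polynomial at the computed roots. horner is a hypothetical helper that spells the formula out with Horner's rule; it is compared against jnp.polyval, which is part of jax.numpy.

```python
import jax.numpy as jnp


def horner(p, x):
    """Hypothetical helper: evaluate p[0]*x**n + ... + p[n-1]*x + p[n] by Horner's rule."""
    acc = jnp.zeros_like(x)
    for c in p:  # highest-degree coefficient first
        acc = acc * x + c
    return acc


# x**2 - 3*x + 2 = (x - 1)(x - 2)
p = jnp.array([1.0, -3.0, 2.0])
r = jnp.roots(p)

# Both evaluations should be close to zero at the roots (up to float32 rounding).
residuals = jnp.polyval(p, r)
print(jnp.abs(horner(p, r)).max(), jnp.abs(residuals).max())
```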
Parameters

p (array_like) – Rank-1 array of polynomial coefficients.

Returns

out – An array containing the roots of the polynomial.

Return type

ndarray
