jax.numpy.polyval

jax.numpy.polyval(p, x, *, unroll=16)

Evaluate a polynomial at specific values.

LAX-backend implementation of numpy.polyval().

The unroll parameter is JAX-specific. It does not affect correctness, but it can have a major impact on performance when evaluating high-order polynomials. It controls how many steps of the lax.scan loop inside the polyval implementation are unrolled. Consider setting unroll=128 (or even higher) to improve runtime performance on accelerators, at the cost of increased compilation time.
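A minimal sketch of passing unroll, assuming a high-order polynomial evaluated under jax.jit. The polynomial and the helper names eval_default and eval_unrolled are illustrative only. Because unroll changes how the loop is compiled rather than what it computes, both versions should agree up to floating-point rounding. No timings are shown, since the speedup depends on the backend and should be measured on your own hardware.

```python
import jax
import jax.numpy as jnp

# Illustrative degree-999 polynomial: 1000 coefficients, highest degree first.
p = jnp.linspace(-1.0, 1.0, 1000) / 1000.0
x = jnp.linspace(-0.5, 0.5, 4096)

@jax.jit
def eval_default(p, x):
    # Uses the default unroll=16 from the signature above.
    return jnp.polyval(p, x)

@jax.jit
def eval_unrolled(p, x):
    # unroll is keyword-only; a larger value trades compile time for runtime.
    return jnp.polyval(p, x, unroll=128)

y16 = eval_default(p, x)
y128 = eval_unrolled(p, x)
```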

Original docstring below.

Note

This forms part of the old polynomial API. Since NumPy version 1.4, the new polynomial API defined in numpy.polynomial is preferred. A summary of the differences can be found in the transition guide.
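For comparison with the newer API mentioned in this note: numpy.polynomial.Polynomial stores coefficients in increasing order of degree, which is the reverse of the order polyval expects. A minimal NumPy sketch (NumPy rather than JAX, since the note concerns NumPy's own APIs) evaluating the same polynomial both ways:

```python
import numpy as np

# Old API: coefficients from highest degree down to the constant term.
# Represents 3*x**2 + 0*x + 1.
p_old = [3.0, 0.0, 1.0]

# New API: coefficients from the constant term up to the highest degree,
# so the list is reversed before building the Polynomial.
p_new = np.polynomial.Polynomial(p_old[::-1])

x = 2.0
old_value = np.polyval(p_old, x)  # 3*4 + 0*2 + 1 = 13.0
new_value = p_new(x)              # same polynomial, same value
```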

If p is of length N, this function returns the value:

p[0]*x**(N-1) + p[1]*x**(N-2) + ... + p[N-2]*x + p[N-1]
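This sum is usually evaluated with Horner's rule, which rewrites it as (...((p[0]*x + p[1])*x + p[2])*x + ...) + p[N-1], so each step costs one multiply and one add. The sketch below expresses that loop with jax.lax.scan, in line with the lax.scan mentioned in the unroll description. horner_polyval is a hypothetical helper written for illustration, not JAX's source; it assumes p is 1-D.

```python
import jax.numpy as jnp
from jax import lax

def horner_polyval(p, x):
    """Hypothetical helper: evaluate p[0]*x**(N-1) + ... + p[N-1] via Horner's rule."""
    p = jnp.asarray(p)
    x = jnp.asarray(x)
    dtype = jnp.result_type(p, x)
    # Accumulator starts at zero with the shape of x.
    init = jnp.zeros(x.shape, dtype=dtype)

    def step(acc, coeff):
        # Multiply by x, then add the next coefficient. After N steps,
        # p[0] has been multiplied by x exactly N-1 times.
        return acc * x + coeff, None

    result, _ = lax.scan(step, init, p.astype(dtype))
    return result

p = jnp.array([2.0, -3.0, 0.0, 5.0])  # 2x**3 - 3x**2 + 0x + 5
x = jnp.array([0.0, 1.0, 2.0])
y = horner_polyval(p, x)              # values 5, 4, 9
```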

If x is a sequence, then p(x) is returned for each element of x. If x is another polynomial then the composite polynomial p(x(t)) is returned.
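A short sketch of both cases. jax.numpy functions typically reject raw Python lists, so the sequence is wrapped with jnp.asarray first. The poly1d composition path is a NumPy feature, so that part uses numpy.polyval directly; this example makes no claim about whether jax.numpy.polyval accepts poly1d inputs.

```python
import numpy as np
import jax.numpy as jnp

# x is a sequence: p(x) is returned for each element.
p = jnp.array([1.0, 0.0, -1.0])  # x**2 - 1
y = jnp.polyval(p, jnp.asarray([0.0, 1.0, 2.0, 3.0]))  # [-1, 0, 3, 8]

# x is a polynomial: the composite polynomial p(x(t)) is returned.
p_np = np.poly1d([1.0, 0.0, 0.0])  # p(u) = u**2
x_np = np.poly1d([1.0, 1.0])       # x(t) = t + 1
composite = np.polyval(p_np, x_np)  # poly1d for t**2 + 2t + 1
```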

Parameters:
  • p (array_like or poly1d object) – 1D array of polynomial coefficients (including coefficients equal to zero) from highest degree to the constant term, or an instance of poly1d.

  • x (array_like or poly1d object) – A number, an array of numbers, or an instance of poly1d, at which to evaluate p.

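  • unroll (int) – Number of lax.scan steps to unroll. It affects performance only, not the result (see the description above).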
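The requirement to include zero coefficients matters because position encodes degree. A minimal sketch showing that dropping a zero silently changes the polynomial:

```python
import jax.numpy as jnp

x = 2.0

# 3*x**2 + 0*x + 1: the zero x**1 coefficient must still be present.
with_zero = jnp.polyval(jnp.array([3.0, 0.0, 1.0]), x)  # 3*4 + 1 = 13

# Omitting it turns the polynomial into 3*x + 1.
without_zero = jnp.polyval(jnp.array([3.0, 1.0]), x)    # 3*2 + 1 = 7
```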

Returns:

values – If x is a poly1d instance, the result is the composition of the two polynomials, i.e., x is “substituted” in p and the simplified result is returned. In addition, the type of x (array_like or poly1d) determines the type of the output: if x is array_like, values is array_like; if x is a poly1d object, values is also a poly1d object.

Return type:

ndarray or poly1d
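In the array_like case under JAX, the result comes back as a jax.Array rather than a NumPy ndarray, and with a 1-D p it takes the shape of x. A small check, assuming a 1-D coefficient array:

```python
import jax
import jax.numpy as jnp

p = jnp.array([1.0, 2.0, 3.0])     # x**2 + 2x + 3
x = jnp.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

values = jnp.polyval(p, x)          # one value per element of x: shape (2, 3)
scalar_value = jnp.polyval(p, 2.0)  # scalar x gives a 0-d array: 4 + 4 + 3 = 11
```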
