jax.numpy.polyval
jax.numpy.polyval(p, x, *, unroll=16)
Evaluates the polynomial at specific values.
JAX implementation of numpy.polyval().
For the 1D polynomial coefficients p of length M, the function returns the value:
\[p_0 x^{M - 1} + p_1 x^{M - 2} + \cdots + p_{M - 1}\]
- Parameters:
p (ArrayLike) – An array of polynomial coefficients of shape (M,).
x (ArrayLike) – A number or an array of numbers.
unroll (int) – Controls the number of unrolled steps in lax.scan. It must be specified statically.
- Returns:
An array of the same shape as x.
- Return type:
Array
Note
The unroll parameter is JAX-specific. It does not affect correctness but can have a major impact on performance when evaluating high-order polynomials. The parameter controls the number of unrolled steps with lax.scan inside the jnp.polyval implementation. Consider setting unroll=128 (or even higher) to improve runtime performance on accelerators, at the cost of increased compilation time.
See also
jax.numpy.polyfit(): Least squares polynomial fit.
jax.numpy.poly(): Finds the coefficients of a polynomial with given roots.
jax.numpy.roots(): Computes the roots of a polynomial for given coefficients.
Examples
>>> import jax.numpy as jnp
>>> p = jnp.array([2, 5, 1])
>>> jnp.polyval(p, 3)
Array(34., dtype=float32)
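The result above can be checked against the formula by hand: 2·3² + 5·3 + 1 = 34. As a minimal sketch, the same value can be reproduced with Horner's rule, a standard way to evaluate a polynomial from its highest-degree coefficient down. The helper name `horner` is hypothetical and is not part of JAX; this page does not say that jnp.polyval is implemented exactly this way, only that it uses lax.scan internally.

```python
import jax.numpy as jnp


def horner(p, x):
    # Hypothetical helper (not a JAX API): evaluates
    # p[0] * x**(M-1) + p[1] * x**(M-2) + ... + p[M-1] using Horner's rule.
    x = jnp.asarray(x, dtype=jnp.float32)
    acc = jnp.zeros_like(x)
    for c in p:  # coefficients are visited from highest degree to lowest
        acc = acc * x + c
    return acc


p = jnp.array([2.0, 5.0, 1.0])
manual = horner(p, 3.0)         # step by step: 2 -> 2*3 + 5 = 11 -> 11*3 + 1 = 34
builtin = jnp.polyval(p, 3.0)   # library result for comparison
```

Both `manual` and `builtin` should hold the same value as the doctest above.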
If x is a 2D array, polyval returns a 2D array with the same shape as x:
>>> x = jnp.array([[2, 1, 5],
...                [3, 4, 7],
...                [1, 3, 5]])
>>> jnp.polyval(p, x)
Array([[ 19.,   8.,  76.],
       [ 34.,  53., 134.],
       [  8.,  34.,  76.]], dtype=float32)
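The unroll parameter described in the note can be exercised as in the following sketch. The polynomial here is made up for illustration (1001 equal coefficients, chosen small so the values stay finite on [-1, 1]); it is an assumption, not an example from the JAX documentation. The sketch only shows that changing unroll leaves the result unchanged up to floating-point tolerance; any speedup depends on the hardware and polynomial order and is not measured here.

```python
import jax.numpy as jnp

# Hypothetical high-order polynomial: 1001 coefficients, all equal to 1e-3.
p = jnp.full((1001,), 1e-3, dtype=jnp.float32)
x = jnp.linspace(-1.0, 1.0, 8)

# Default setting (unroll=16).
default = jnp.polyval(p, x)

# Larger unroll factor, passed as a static Python int keyword argument.
unrolled = jnp.polyval(p, x, unroll=128)

# unroll affects performance and compile time only, so the values should agree.
same = bool(jnp.allclose(default, unrolled, rtol=1e-4, atol=1e-5))
```

Because unroll must be specified statically, it is passed as a plain Python integer rather than as a traced array value.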