jax.numpy.polyval

jax.numpy.polyval(p, x, *, unroll=16)

Evaluates the polynomial at specific values.

JAX implementation of numpy.polyval().

For a 1D array of polynomial coefficients p of length M, the function returns the value:

\[p_0 x^{M - 1} + p_1 x^{M - 2} + \cdots + p_{M - 1}\]

Parameters:
  • p (ArrayLike) – An array of polynomial coefficients of shape (M,).

  • x (ArrayLike) – A number or an array of numbers.

  • unroll (int) – The number of lax.scan steps to unroll. It must be specified statically.

Returns:

An array with the same shape as x.

Return type:

Array

Note

The unroll parameter is JAX-specific. It does not affect correctness but can have a major impact on performance when evaluating high-order polynomials. It controls how many steps of the lax.scan loop inside the jnp.polyval implementation are unrolled. Consider setting unroll=128 (or even higher) to improve runtime performance on accelerators, at the cost of increased compilation time.
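As a rough illustration of the note above, the sketch below evaluates the same polynomial with the default unroll and with unroll=128 and checks that the results agree. The coefficient array and the evaluation points are made up for the example; no timing is shown, since any speed-up depends on the backend and on the polynomial order.

```python
import jax.numpy as jnp

# Hypothetical high-order polynomial: 1001 equal coefficients (illustrative only).
p = jnp.full((1001,), 1.0 / 1001.0)
x = jnp.linspace(-1.0, 1.0, 8)

# unroll is keyword-only and must be a plain Python int (a static value).
y_default = jnp.polyval(p, x)               # uses unroll=16
y_unrolled = jnp.polyval(p, x, unroll=128)  # more unrolled steps per scan iteration

# The setting changes how the loop is compiled, not the values computed.
print(jnp.allclose(y_default, y_unrolled, atol=1e-5))
```

Each distinct unroll value leads to a separate compilation, so it is usually best to pick one value per workload rather than vary it between calls.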

Examples

>>> p = jnp.array([2, 5, 1])
>>> jnp.polyval(p, 3)
Array(34., dtype=float32)

If x is a 2D array, polyval returns a 2D array with the same shape as x:

>>> x = jnp.array([[2, 1, 5],
...                [3, 4, 7],
...                [1, 3, 5]])
>>> jnp.polyval(p, x)
Array([[ 19.,   8.,  76.],
       [ 34.,  53., 134.],
       [  8.,  34.,  76.]], dtype=float32)
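The values above can be reproduced with Horner's scheme, which rewrites the polynomial as nested multiply-add steps. The function below is a minimal sketch built on lax.scan; horner_polyval is a hypothetical helper written for this example, not the library's actual source, and it assumes float32 inputs for simplicity.

```python
import jax.numpy as jnp
from jax import lax

def horner_polyval(p, x, unroll=16):
    """Illustrative Horner evaluation (hypothetical helper, not jnp.polyval itself)."""
    p = jnp.asarray(p, dtype=jnp.float32)
    x = jnp.asarray(x, dtype=jnp.float32)

    def step(acc, coeff):
        # Fold in the next coefficient: acc <- acc * x + coeff.
        return acc * x + coeff, None

    # Start from zero and scan over coefficients, highest power first.
    result, _ = lax.scan(step, jnp.zeros_like(x), p, unroll=unroll)
    return result

p = jnp.array([2, 5, 1])
x = jnp.array([[2, 1, 5],
               [3, 4, 7],
               [1, 3, 5]])
print(jnp.allclose(horner_polyval(p, x), jnp.polyval(p, x)))
```

For p = [2, 5, 1] and x = 3, the steps are 0 * 3 + 2 = 2, then 2 * 3 + 5 = 11, then 11 * 3 + 1 = 34, matching the first example.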